@@ -156,7 +156,7 @@ static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
156
156
return lookupOperator (decl, decl->getASTContext ().Id_EqualsOperator , isValid);
157
157
}
158
158
159
- static ValueDecl *getMinusOperator (NominalTypeDecl *decl) {
159
+ static FuncDecl *getMinusOperator (NominalTypeDecl *decl) {
160
160
auto binaryIntegerProto =
161
161
decl->getASTContext ().getProtocol (KnownProtocolKind::BinaryInteger);
162
162
auto module = decl->getModuleContext ();
@@ -188,8 +188,9 @@ static ValueDecl *getMinusOperator(NominalTypeDecl *decl) {
188
188
return true ;
189
189
};
190
190
191
- return lookupOperator (decl, decl->getASTContext ().getIdentifier (" -" ),
192
- isValid);
191
+ ValueDecl *result =
192
+ lookupOperator (decl, decl->getASTContext ().getIdentifier (" -" ), isValid);
193
+ return dyn_cast_or_null<FuncDecl>(result);
193
194
}
194
195
195
196
static ValueDecl *getPlusEqualOperator (NominalTypeDecl *decl, Type distanceTy) {
@@ -223,10 +224,10 @@ static ValueDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
223
224
isValid);
224
225
}
225
226
226
- static void instantiateTemplatedOperator (
227
- ClangImporter::Implementation &impl,
228
- const clang::ClassTemplateSpecializationDecl *classDecl,
229
- clang::BinaryOperatorKind operatorKind) {
227
+ static clang::FunctionDecl *
228
+ instantiateTemplatedOperator ( ClangImporter::Implementation &impl,
229
+ const clang::CXXRecordDecl *classDecl,
230
+ clang::BinaryOperatorKind operatorKind) {
230
231
231
232
clang::ASTContext &clangCtx = impl.getClangASTContext ();
232
233
clang::Sema &clangSema = impl.getClangSema ();
@@ -252,6 +253,7 @@ static void instantiateTemplatedOperator(
252
253
if (auto clangCallee = best->Function ) {
253
254
auto lookupTable = impl.findLookupTable (classDecl);
254
255
addEntryToLookupTable (*lookupTable, clangCallee, impl.getNameImporter ());
256
+ return clangCallee;
255
257
}
256
258
break ;
257
259
}
@@ -260,6 +262,82 @@ static void instantiateTemplatedOperator(
260
262
case clang::OR_Deleted:
261
263
break ;
262
264
}
265
+
266
+ return nullptr ;
267
+ }
268
+
269
+ // / Warning: This function emits an error and stops compilation if the
270
+ // / underlying operator function is unavailable in Swift for the current target
271
+ // / (see `clang::Sema::DiagnoseAvailabilityOfDecl`).
272
+ static bool makeOperatorFunc (ClangImporter::Implementation &impl,
273
+ const clang::CXXRecordDecl *classDecl,
274
+ clang::BinaryOperatorKind operatorKind) {
275
+ auto &clangCtx = impl.getClangASTContext ();
276
+ auto &clangSema = impl.getClangSema ();
277
+ auto classTy = clangCtx.getRecordType (classDecl);
278
+ auto classTyInfo = clangCtx.getTrivialTypeSourceInfo (classTy);
279
+
280
+ clang::OverloadedOperatorKind opKind =
281
+ clang::BinaryOperator::getOverloadedOperator (operatorKind);
282
+ const char *opSpelling = clang::getOperatorSpelling (opKind);
283
+
284
+ auto declName = clang::DeclarationName (&clangCtx.Idents .get (opSpelling));
285
+ auto declContext =
286
+ const_cast <clang::CXXRecordDecl *>(classDecl)->getDeclContext ();
287
+ auto equalEqualTy =
288
+ clangCtx.getFunctionType (clangCtx.BoolTy , {classTy, classTy},
289
+ clang::FunctionProtoType::ExtProtoInfo ());
290
+
291
+ // Create a `bool operator==(T, T)` function.
292
+ auto equalEqualDecl = clang::FunctionDecl::Create (
293
+ clangCtx, declContext, clang::SourceLocation (), clang::SourceLocation (),
294
+ declName, equalEqualTy,
295
+ clangCtx.getTrivialTypeSourceInfo (clangCtx.BoolTy ),
296
+ clang::StorageClass::SC_Static);
297
+ equalEqualDecl->setImplicit ();
298
+ equalEqualDecl->setImplicitlyInline ();
299
+ equalEqualDecl->setAccess (clang::AccessSpecifier::AS_public);
300
+
301
+ // Create the parameters of the function. They are not referenced from source
302
+ // code, so they don't need to have a name.
303
+ auto lhsParamId = nullptr ;
304
+ auto lhsParamDecl = clang::ParmVarDecl::Create (
305
+ clangCtx, equalEqualDecl, clang::SourceLocation (),
306
+ clang::SourceLocation (), lhsParamId, classTy, classTyInfo,
307
+ clang::StorageClass::SC_None, /* DefArg*/ nullptr );
308
+ auto lhsParamRefExpr = new (clangCtx) clang::DeclRefExpr (
309
+ clangCtx, lhsParamDecl, false , classTy, clang::ExprValueKind::VK_LValue,
310
+ clang::SourceLocation ());
311
+
312
+ auto rhsParamId = nullptr ;
313
+ auto rhsParamDecl = clang::ParmVarDecl::Create (
314
+ clangCtx, equalEqualDecl, clang::SourceLocation (),
315
+ clang::SourceLocation (), rhsParamId, classTy, classTyInfo,
316
+ clang::StorageClass::SC_None, nullptr );
317
+ auto rhsParamRefExpr = new (clangCtx) clang::DeclRefExpr (
318
+ clangCtx, rhsParamDecl, false , classTy, clang::ExprValueKind::VK_LValue,
319
+ clang::SourceLocation ());
320
+
321
+ equalEqualDecl->setParams ({lhsParamDecl, rhsParamDecl});
322
+
323
+ // Lookup the `operator==` function that will be called under the hood.
324
+ clang::UnresolvedSet<16 > operators;
325
+ // Note: calling `CreateOverloadedBinOp` emits an error if the looked up
326
+ // function is unavailable for the current target.
327
+ auto underlyingCallResult = clangSema.CreateOverloadedBinOp (
328
+ clang::SourceLocation (), operatorKind, operators, lhsParamRefExpr,
329
+ rhsParamRefExpr);
330
+ if (!underlyingCallResult.isUsable ())
331
+ return false ;
332
+ auto underlyingCall = underlyingCallResult.get ();
333
+
334
+ auto equalEqualBody = clang::ReturnStmt::Create (
335
+ clangCtx, clang::SourceLocation (), underlyingCall, nullptr );
336
+ equalEqualDecl->setBody (equalEqualBody);
337
+
338
+ auto lookupTable = impl.findLookupTable (classDecl);
339
+ addEntryToLookupTable (*lookupTable, equalEqualDecl, impl.getNameImporter ());
340
+ return true ;
263
341
}
264
342
265
343
bool swift::isIterator (const clang::CXXRecordDecl *clangDecl) {
@@ -349,15 +427,26 @@ void swift::conformToCxxIteratorIfNeeded(
349
427
if (!successorTy || successorTy->getAnyNominal () != decl)
350
428
return ;
351
429
352
- // If this is a templated class, `operator==` might be templated as well.
353
- // Try to instantiate it.
354
- if (auto templateSpec =
355
- dyn_cast<clang::ClassTemplateSpecializationDecl>(clangDecl)) {
356
- instantiateTemplatedOperator (impl, templateSpec,
357
- clang::BinaryOperatorKind::BO_EQ);
358
- }
359
430
// Check if present: `func ==`
360
431
auto equalEqual = getEqualEqualOperator (decl);
432
+ if (!equalEqual) {
433
+ // If this class is inherited, `operator==` might be defined for a base
434
+ // class. If this is a templated class, `operator==` might be templated as
435
+ // well. Try to instantiate it.
436
+ clang::FunctionDecl* instantiated = instantiateTemplatedOperator (
437
+ impl, clangDecl, clang::BinaryOperatorKind::BO_EQ);
438
+ if (instantiated && !impl.isUnavailableInSwift (instantiated)) {
439
+ // If `operator==` was instantiated successfully, try to find `func ==`
440
+ // again.
441
+ equalEqual = getEqualEqualOperator (decl);
442
+ if (!equalEqual) {
443
+ // If `func ==` still can't be found, it might be defined for a base
444
+ // class of the current class.
445
+ makeOperatorFunc (impl, clangDecl, clang::BinaryOperatorKind::BO_EQ);
446
+ equalEqual = getEqualEqualOperator (decl);
447
+ }
448
+ }
449
+ }
361
450
if (!equalEqual)
362
451
return ;
363
452
@@ -371,12 +460,19 @@ void swift::conformToCxxIteratorIfNeeded(
371
460
372
461
// Try to conform to UnsafeCxxRandomAccessIterator if possible.
373
462
374
- if (auto templateSpec =
375
- dyn_cast<clang::ClassTemplateSpecializationDecl>(clangDecl)) {
376
- instantiateTemplatedOperator (impl, templateSpec,
377
- clang::BinaryOperatorKind::BO_Sub);
463
+ // Check if present: `func -`
464
+ auto minus = getMinusOperator (decl);
465
+ if (!minus) {
466
+ clang::FunctionDecl *instantiated = instantiateTemplatedOperator (
467
+ impl, clangDecl, clang::BinaryOperatorKind::BO_Sub);
468
+ if (instantiated && !impl.isUnavailableInSwift (instantiated)) {
469
+ minus = getMinusOperator (decl);
470
+ if (!minus) {
471
+ makeOperatorFunc (impl, clangDecl, clang::BinaryOperatorKind::BO_Sub);
472
+ minus = getMinusOperator (decl);
473
+ }
474
+ }
378
475
}
379
- auto minus = dyn_cast_or_null<FuncDecl>(getMinusOperator (decl));
380
476
if (!minus)
381
477
return ;
382
478
auto distanceTy = minus->getResultInterfaceType ();
0 commit comments