@@ -156,7 +156,7 @@ static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
156
156
return lookupOperator (decl, decl->getASTContext ().Id_EqualsOperator , isValid);
157
157
}
158
158
159
- static ValueDecl *getMinusOperator (NominalTypeDecl *decl) {
159
+ static FuncDecl *getMinusOperator (NominalTypeDecl *decl) {
160
160
auto binaryIntegerProto =
161
161
decl->getASTContext ().getProtocol (KnownProtocolKind::BinaryInteger);
162
162
auto module = decl->getModuleContext ();
@@ -188,11 +188,12 @@ static ValueDecl *getMinusOperator(NominalTypeDecl *decl) {
188
188
return true ;
189
189
};
190
190
191
- return lookupOperator (decl, decl->getASTContext ().getIdentifier (" -" ),
192
- isValid);
191
+ ValueDecl *result =
192
+ lookupOperator (decl, decl->getASTContext ().getIdentifier (" -" ), isValid);
193
+ return dyn_cast_or_null<FuncDecl>(result);
193
194
}
194
195
195
- static ValueDecl *getPlusEqualOperator (NominalTypeDecl *decl, Type distanceTy) {
196
+ static FuncDecl *getPlusEqualOperator (NominalTypeDecl *decl, Type distanceTy) {
196
197
auto isValid = [&](ValueDecl *plusEqualOp) -> bool {
197
198
auto plusEqual = dyn_cast<FuncDecl>(plusEqualOp);
198
199
if (!plusEqual || !plusEqual->hasParameterList ())
@@ -219,14 +220,15 @@ static ValueDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
219
220
return true ;
220
221
};
221
222
222
- return lookupOperator (decl, decl->getASTContext ().getIdentifier (" +=" ),
223
- isValid);
223
+ ValueDecl *result =
224
+ lookupOperator (decl, decl->getASTContext ().getIdentifier (" +=" ), isValid);
225
+ return dyn_cast_or_null<FuncDecl>(result);
224
226
}
225
227
226
- static void instantiateTemplatedOperator (
227
- ClangImporter::Implementation &impl,
228
- const clang::ClassTemplateSpecializationDecl *classDecl,
229
- clang::BinaryOperatorKind operatorKind) {
228
+ static clang::FunctionDecl *
229
+ instantiateTemplatedOperator ( ClangImporter::Implementation &impl,
230
+ const clang::CXXRecordDecl *classDecl,
231
+ clang::BinaryOperatorKind operatorKind) {
230
232
231
233
clang::ASTContext &clangCtx = impl.getClangASTContext ();
232
234
clang::Sema &clangSema = impl.getClangSema ();
@@ -252,6 +254,7 @@ static void instantiateTemplatedOperator(
252
254
if (auto clangCallee = best->Function ) {
253
255
auto lookupTable = impl.findLookupTable (classDecl);
254
256
addEntryToLookupTable (*lookupTable, clangCallee, impl.getNameImporter ());
257
+ return clangCallee;
255
258
}
256
259
break ;
257
260
}
@@ -260,6 +263,95 @@ static void instantiateTemplatedOperator(
260
263
case clang::OR_Deleted:
261
264
break ;
262
265
}
266
+
267
+ return nullptr ;
268
+ }
269
+
270
+ // / Warning: This function emits an error and stops compilation if the
271
+ // / underlying operator function is unavailable in Swift for the current target
272
+ // / (see `clang::Sema::DiagnoseAvailabilityOfDecl`).
273
+ static bool synthesizeCXXOperator (ClangImporter::Implementation &impl,
274
+ const clang::CXXRecordDecl *classDecl,
275
+ clang::BinaryOperatorKind operatorKind,
276
+ clang::QualType lhsTy, clang::QualType rhsTy,
277
+ clang::QualType returnTy) {
278
+ auto &clangCtx = impl.getClangASTContext ();
279
+ auto &clangSema = impl.getClangSema ();
280
+
281
+ clang::OverloadedOperatorKind opKind =
282
+ clang::BinaryOperator::getOverloadedOperator (operatorKind);
283
+ const char *opSpelling = clang::getOperatorSpelling (opKind);
284
+
285
+ auto declName = clang::DeclarationName (&clangCtx.Idents .get (opSpelling));
286
+
287
+ // Determine the Clang decl context where the new operator function will be
288
+ // created. We use the translation unit as the decl context of the new
289
+ // operator, otherwise, the operator might get imported as a static member
290
+ // function of a different type (e.g. an operator declared inside of a C++
291
+ // namespace would get imported as a member function of a Swift enum), which
292
+ // would make the operator un-discoverable to Swift name lookup.
293
+ auto declContext =
294
+ const_cast <clang::CXXRecordDecl *>(classDecl)->getDeclContext ();
295
+ while (!declContext->isTranslationUnit ()) {
296
+ declContext = declContext->getParent ();
297
+ }
298
+
299
+ auto equalEqualTy = clangCtx.getFunctionType (
300
+ returnTy, {lhsTy, rhsTy}, clang::FunctionProtoType::ExtProtoInfo ());
301
+
302
+ // Create a `bool operator==(T, T)` function.
303
+ auto equalEqualDecl = clang::FunctionDecl::Create (
304
+ clangCtx, declContext, clang::SourceLocation (), clang::SourceLocation (),
305
+ declName, equalEqualTy, clangCtx.getTrivialTypeSourceInfo (returnTy),
306
+ clang::StorageClass::SC_Static);
307
+ equalEqualDecl->setImplicit ();
308
+ equalEqualDecl->setImplicitlyInline ();
309
+ // If this is a static member function of a class, it needs to be public.
310
+ equalEqualDecl->setAccess (clang::AccessSpecifier::AS_public);
311
+
312
+ // Create the parameters of the function. They are not referenced from source
313
+ // code, so they don't need to have a name.
314
+ auto lhsParamId = nullptr ;
315
+ auto lhsTyInfo = clangCtx.getTrivialTypeSourceInfo (lhsTy);
316
+ auto lhsParamDecl = clang::ParmVarDecl::Create (
317
+ clangCtx, equalEqualDecl, clang::SourceLocation (),
318
+ clang::SourceLocation (), lhsParamId, lhsTy, lhsTyInfo,
319
+ clang::StorageClass::SC_None, /* DefArg*/ nullptr );
320
+ auto lhsParamRefExpr = new (clangCtx) clang::DeclRefExpr (
321
+ clangCtx, lhsParamDecl, false , lhsTy, clang::ExprValueKind::VK_LValue,
322
+ clang::SourceLocation ());
323
+
324
+ auto rhsParamId = nullptr ;
325
+ auto rhsTyInfo = clangCtx.getTrivialTypeSourceInfo (rhsTy);
326
+ auto rhsParamDecl = clang::ParmVarDecl::Create (
327
+ clangCtx, equalEqualDecl, clang::SourceLocation (),
328
+ clang::SourceLocation (), rhsParamId, rhsTy, rhsTyInfo,
329
+ clang::StorageClass::SC_None, nullptr );
330
+ auto rhsParamRefExpr = new (clangCtx) clang::DeclRefExpr (
331
+ clangCtx, rhsParamDecl, false , rhsTy, clang::ExprValueKind::VK_LValue,
332
+ clang::SourceLocation ());
333
+
334
+ equalEqualDecl->setParams ({lhsParamDecl, rhsParamDecl});
335
+
336
+ // Lookup the `operator==` function that will be called under the hood.
337
+ clang::UnresolvedSet<16 > operators;
338
+ // Note: calling `CreateOverloadedBinOp` emits an error if the looked up
339
+ // function is unavailable for the current target.
340
+ auto underlyingCallResult = clangSema.CreateOverloadedBinOp (
341
+ clang::SourceLocation (), operatorKind, operators, lhsParamRefExpr,
342
+ rhsParamRefExpr);
343
+ if (!underlyingCallResult.isUsable ())
344
+ return false ;
345
+ auto underlyingCall = underlyingCallResult.get ();
346
+
347
+ auto equalEqualBody = clang::ReturnStmt::Create (
348
+ clangCtx, clang::SourceLocation (), underlyingCall, nullptr );
349
+ equalEqualDecl->setBody (equalEqualBody);
350
+
351
+ impl.synthesizedAndAlwaysVisibleDecls .insert (equalEqualDecl);
352
+ auto lookupTable = impl.findLookupTable (classDecl);
353
+ addEntryToLookupTable (*lookupTable, equalEqualDecl, impl.getNameImporter ());
354
+ return true ;
263
355
}
264
356
265
357
bool swift::isIterator (const clang::CXXRecordDecl *clangDecl) {
@@ -274,6 +366,7 @@ void swift::conformToCxxIteratorIfNeeded(
274
366
assert (decl);
275
367
assert (clangDecl);
276
368
ASTContext &ctx = decl->getASTContext ();
369
+ clang::ASTContext &clangCtx = clangDecl->getASTContext ();
277
370
278
371
if (!ctx.getProtocol (KnownProtocolKind::UnsafeCxxInputIterator))
279
372
return ;
@@ -349,15 +442,28 @@ void swift::conformToCxxIteratorIfNeeded(
349
442
if (!successorTy || successorTy->getAnyNominal () != decl)
350
443
return ;
351
444
352
- // If this is a templated class, `operator==` might be templated as well.
353
- // Try to instantiate it.
354
- if (auto templateSpec =
355
- dyn_cast<clang::ClassTemplateSpecializationDecl>(clangDecl)) {
356
- instantiateTemplatedOperator (impl, templateSpec,
357
- clang::BinaryOperatorKind::BO_EQ);
358
- }
359
445
// Check if present: `func ==`
360
446
auto equalEqual = getEqualEqualOperator (decl);
447
+ if (!equalEqual) {
448
+ // If this class is inherited, `operator==` might be defined for a base
449
+ // class. If this is a templated class, `operator==` might be templated as
450
+ // well. Try to instantiate it.
451
+ clang::FunctionDecl *instantiated = instantiateTemplatedOperator (
452
+ impl, clangDecl, clang::BinaryOperatorKind::BO_EQ);
453
+ if (instantiated && !impl.isUnavailableInSwift (instantiated)) {
454
+ // If `operator==` was instantiated successfully, try to find `func ==`
455
+ // again.
456
+ equalEqual = getEqualEqualOperator (decl);
457
+ if (!equalEqual) {
458
+ // If `func ==` still can't be found, it might be defined for a base
459
+ // class of the current class.
460
+ auto paramTy = clangCtx.getRecordType (clangDecl);
461
+ synthesizeCXXOperator (impl, clangDecl, clang::BinaryOperatorKind::BO_EQ,
462
+ paramTy, paramTy, clangCtx.BoolTy );
463
+ equalEqual = getEqualEqualOperator (decl);
464
+ }
465
+ }
466
+ }
361
467
if (!equalEqual)
362
468
return ;
363
469
@@ -371,18 +477,46 @@ void swift::conformToCxxIteratorIfNeeded(
371
477
372
478
// Try to conform to UnsafeCxxRandomAccessIterator if possible.
373
479
374
- if (auto templateSpec =
375
- dyn_cast<clang::ClassTemplateSpecializationDecl>(clangDecl)) {
376
- instantiateTemplatedOperator (impl, templateSpec,
377
- clang::BinaryOperatorKind::BO_Sub);
480
+ // Check if present: `func -`
481
+ auto minus = getMinusOperator (decl);
482
+ if (!minus) {
483
+ clang::FunctionDecl *instantiated = instantiateTemplatedOperator (
484
+ impl, clangDecl, clang::BinaryOperatorKind::BO_Sub);
485
+ if (instantiated && !impl.isUnavailableInSwift (instantiated)) {
486
+ minus = getMinusOperator (decl);
487
+ if (!minus) {
488
+ clang::QualType returnTy = instantiated->getReturnType ();
489
+ auto paramTy = clangCtx.getRecordType (clangDecl);
490
+ synthesizeCXXOperator (impl, clangDecl,
491
+ clang::BinaryOperatorKind::BO_Sub, paramTy,
492
+ paramTy, returnTy);
493
+ minus = getMinusOperator (decl);
494
+ }
495
+ }
378
496
}
379
- auto minus = dyn_cast_or_null<FuncDecl>(getMinusOperator (decl));
380
497
if (!minus)
381
498
return ;
382
499
auto distanceTy = minus->getResultInterfaceType ();
383
500
// distanceTy conforms to BinaryInteger, this is ensured by getMinusOperator.
384
501
385
- auto plusEqual = dyn_cast_or_null<FuncDecl>(getPlusEqualOperator (decl, distanceTy));
502
+ auto plusEqual = getPlusEqualOperator (decl, distanceTy);
503
+ if (!plusEqual) {
504
+ clang::FunctionDecl *instantiated = instantiateTemplatedOperator (
505
+ impl, clangDecl, clang::BinaryOperatorKind::BO_AddAssign);
506
+ if (instantiated && !impl.isUnavailableInSwift (instantiated)) {
507
+ plusEqual = getPlusEqualOperator (decl, distanceTy);
508
+ if (!plusEqual) {
509
+ clang::QualType returnTy = instantiated->getReturnType ();
510
+ auto clangMinus = cast<clang::FunctionDecl>(minus->getClangDecl ());
511
+ auto lhsTy = clangCtx.getRecordType (clangDecl);
512
+ auto rhsTy = clangMinus->getReturnType ();
513
+ synthesizeCXXOperator (impl, clangDecl,
514
+ clang::BinaryOperatorKind::BO_AddAssign, lhsTy,
515
+ rhsTy, returnTy);
516
+ plusEqual = getPlusEqualOperator (decl, distanceTy);
517
+ }
518
+ }
519
+ }
386
520
if (!plusEqual)
387
521
return ;
388
522
0 commit comments