@@ -156,7 +156,7 @@ static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
156
156
return lookupOperator (decl, decl->getASTContext ().Id_EqualsOperator , isValid);
157
157
}
158
158
159
- static ValueDecl *getMinusOperator (NominalTypeDecl *decl) {
159
+ static FuncDecl *getMinusOperator (NominalTypeDecl *decl) {
160
160
auto binaryIntegerProto =
161
161
decl->getASTContext ().getProtocol (KnownProtocolKind::BinaryInteger);
162
162
auto module = decl->getModuleContext ();
@@ -188,11 +188,12 @@ static ValueDecl *getMinusOperator(NominalTypeDecl *decl) {
188
188
return true ;
189
189
};
190
190
191
- return lookupOperator (decl, decl->getASTContext ().getIdentifier (" -" ),
192
- isValid);
191
+ ValueDecl *result =
192
+ lookupOperator (decl, decl->getASTContext ().getIdentifier (" -" ), isValid);
193
+ return dyn_cast_or_null<FuncDecl>(result);
193
194
}
194
195
195
- static ValueDecl *getPlusEqualOperator (NominalTypeDecl *decl, Type distanceTy) {
196
+ static FuncDecl *getPlusEqualOperator (NominalTypeDecl *decl, Type distanceTy) {
196
197
auto isValid = [&](ValueDecl *plusEqualOp) -> bool {
197
198
auto plusEqual = dyn_cast<FuncDecl>(plusEqualOp);
198
199
if (!plusEqual || !plusEqual->hasParameterList ())
@@ -219,14 +220,15 @@ static ValueDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
219
220
return true ;
220
221
};
221
222
222
- return lookupOperator (decl, decl->getASTContext ().getIdentifier (" +=" ),
223
- isValid);
223
+ ValueDecl *result =
224
+ lookupOperator (decl, decl->getASTContext ().getIdentifier (" +=" ), isValid);
225
+ return dyn_cast_or_null<FuncDecl>(result);
224
226
}
225
227
226
- static void instantiateTemplatedOperator (
227
- ClangImporter::Implementation &impl,
228
- const clang::ClassTemplateSpecializationDecl *classDecl,
229
- clang::BinaryOperatorKind operatorKind) {
228
+ static clang::FunctionDecl *
229
+ instantiateTemplatedOperator ( ClangImporter::Implementation &impl,
230
+ const clang::CXXRecordDecl *classDecl,
231
+ clang::BinaryOperatorKind operatorKind) {
230
232
231
233
clang::ASTContext &clangCtx = impl.getClangASTContext ();
232
234
clang::Sema &clangSema = impl.getClangSema ();
@@ -252,6 +254,7 @@ static void instantiateTemplatedOperator(
252
254
if (auto clangCallee = best->Function ) {
253
255
auto lookupTable = impl.findLookupTable (classDecl);
254
256
addEntryToLookupTable (*lookupTable, clangCallee, impl.getNameImporter ());
257
+ return clangCallee;
255
258
}
256
259
break ;
257
260
}
@@ -260,6 +263,94 @@ static void instantiateTemplatedOperator(
260
263
case clang::OR_Deleted:
261
264
break ;
262
265
}
266
+
267
+ return nullptr ;
268
+ }
269
+
270
+ // / Warning: This function emits an error and stops compilation if the
271
+ // / underlying operator function is unavailable in Swift for the current target
272
+ // / (see `clang::Sema::DiagnoseAvailabilityOfDecl`).
273
+ static bool synthesizeCXXOperator (ClangImporter::Implementation &impl,
274
+ const clang::CXXRecordDecl *classDecl,
275
+ clang::BinaryOperatorKind operatorKind,
276
+ clang::QualType lhsTy, clang::QualType rhsTy,
277
+ clang::QualType returnTy) {
278
+ auto &clangCtx = impl.getClangASTContext ();
279
+ auto &clangSema = impl.getClangSema ();
280
+
281
+ clang::OverloadedOperatorKind opKind =
282
+ clang::BinaryOperator::getOverloadedOperator (operatorKind);
283
+ const char *opSpelling = clang::getOperatorSpelling (opKind);
284
+
285
+ auto declName = clang::DeclarationName (&clangCtx.Idents .get (opSpelling));
286
+
287
+ // Determine the Clang decl context where the new operator function will be
288
+ // created. We use the translation unit as the decl context of the new
289
+ // operator, otherwise, the operator might get imported as a static member
290
+ // function of a different type (e.g. an operator declared inside of a C++
291
+ // namespace would get imported as a member function of a Swift enum), which
292
+ // would make the operator un-discoverable to Swift name lookup.
293
+ auto declContext =
294
+ const_cast <clang::CXXRecordDecl *>(classDecl)->getDeclContext ();
295
+ while (!declContext->isTranslationUnit ()) {
296
+ declContext = declContext->getParent ();
297
+ }
298
+
299
+ auto equalEqualTy = clangCtx.getFunctionType (
300
+ returnTy, {lhsTy, rhsTy}, clang::FunctionProtoType::ExtProtoInfo ());
301
+
302
+ // Create a `bool operator==(T, T)` function.
303
+ auto equalEqualDecl = clang::FunctionDecl::Create (
304
+ clangCtx, declContext, clang::SourceLocation (), clang::SourceLocation (),
305
+ declName, equalEqualTy, clangCtx.getTrivialTypeSourceInfo (returnTy),
306
+ clang::StorageClass::SC_Static);
307
+ equalEqualDecl->setImplicit ();
308
+ equalEqualDecl->setImplicitlyInline ();
309
+ // If this is a static member function of a class, it needs to be public.
310
+ equalEqualDecl->setAccess (clang::AccessSpecifier::AS_public);
311
+
312
+ // Create the parameters of the function. They are not referenced from source
313
+ // code, so they don't need to have a name.
314
+ auto lhsParamId = nullptr ;
315
+ auto lhsTyInfo = clangCtx.getTrivialTypeSourceInfo (lhsTy);
316
+ auto lhsParamDecl = clang::ParmVarDecl::Create (
317
+ clangCtx, equalEqualDecl, clang::SourceLocation (),
318
+ clang::SourceLocation (), lhsParamId, lhsTy, lhsTyInfo,
319
+ clang::StorageClass::SC_None, /* DefArg*/ nullptr );
320
+ auto lhsParamRefExpr = new (clangCtx) clang::DeclRefExpr (
321
+ clangCtx, lhsParamDecl, false , lhsTy, clang::ExprValueKind::VK_LValue,
322
+ clang::SourceLocation ());
323
+
324
+ auto rhsParamId = nullptr ;
325
+ auto rhsTyInfo = clangCtx.getTrivialTypeSourceInfo (rhsTy);
326
+ auto rhsParamDecl = clang::ParmVarDecl::Create (
327
+ clangCtx, equalEqualDecl, clang::SourceLocation (),
328
+ clang::SourceLocation (), rhsParamId, rhsTy, rhsTyInfo,
329
+ clang::StorageClass::SC_None, nullptr );
330
+ auto rhsParamRefExpr = new (clangCtx) clang::DeclRefExpr (
331
+ clangCtx, rhsParamDecl, false , rhsTy, clang::ExprValueKind::VK_LValue,
332
+ clang::SourceLocation ());
333
+
334
+ equalEqualDecl->setParams ({lhsParamDecl, rhsParamDecl});
335
+
336
+ // Lookup the `operator==` function that will be called under the hood.
337
+ clang::UnresolvedSet<16 > operators;
338
+ // Note: calling `CreateOverloadedBinOp` emits an error if the looked up
339
+ // function is unavailable for the current target.
340
+ auto underlyingCallResult = clangSema.CreateOverloadedBinOp (
341
+ clang::SourceLocation (), operatorKind, operators, lhsParamRefExpr,
342
+ rhsParamRefExpr);
343
+ if (!underlyingCallResult.isUsable ())
344
+ return false ;
345
+ auto underlyingCall = underlyingCallResult.get ();
346
+
347
+ auto equalEqualBody = clang::ReturnStmt::Create (
348
+ clangCtx, clang::SourceLocation (), underlyingCall, nullptr );
349
+ equalEqualDecl->setBody (equalEqualBody);
350
+
351
+ auto lookupTable = impl.findLookupTable (classDecl);
352
+ addEntryToLookupTable (*lookupTable, equalEqualDecl, impl.getNameImporter ());
353
+ return true ;
263
354
}
264
355
265
356
bool swift::isIterator (const clang::CXXRecordDecl *clangDecl) {
@@ -274,6 +365,7 @@ void swift::conformToCxxIteratorIfNeeded(
274
365
assert (decl);
275
366
assert (clangDecl);
276
367
ASTContext &ctx = decl->getASTContext ();
368
+ clang::ASTContext &clangCtx = clangDecl->getASTContext ();
277
369
278
370
if (!ctx.getProtocol (KnownProtocolKind::UnsafeCxxInputIterator))
279
371
return ;
@@ -349,15 +441,28 @@ void swift::conformToCxxIteratorIfNeeded(
349
441
if (!successorTy || successorTy->getAnyNominal () != decl)
350
442
return ;
351
443
352
- // If this is a templated class, `operator==` might be templated as well.
353
- // Try to instantiate it.
354
- if (auto templateSpec =
355
- dyn_cast<clang::ClassTemplateSpecializationDecl>(clangDecl)) {
356
- instantiateTemplatedOperator (impl, templateSpec,
357
- clang::BinaryOperatorKind::BO_EQ);
358
- }
359
444
// Check if present: `func ==`
360
445
auto equalEqual = getEqualEqualOperator (decl);
446
+ if (!equalEqual) {
447
+ // If this class is inherited, `operator==` might be defined for a base
448
+ // class. If this is a templated class, `operator==` might be templated as
449
+ // well. Try to instantiate it.
450
+ clang::FunctionDecl *instantiated = instantiateTemplatedOperator (
451
+ impl, clangDecl, clang::BinaryOperatorKind::BO_EQ);
452
+ if (instantiated && !impl.isUnavailableInSwift (instantiated)) {
453
+ // If `operator==` was instantiated successfully, try to find `func ==`
454
+ // again.
455
+ equalEqual = getEqualEqualOperator (decl);
456
+ if (!equalEqual) {
457
+ // If `func ==` still can't be found, it might be defined for a base
458
+ // class of the current class.
459
+ auto paramTy = clangCtx.getRecordType (clangDecl);
460
+ synthesizeCXXOperator (impl, clangDecl, clang::BinaryOperatorKind::BO_EQ,
461
+ paramTy, paramTy, clangCtx.BoolTy );
462
+ equalEqual = getEqualEqualOperator (decl);
463
+ }
464
+ }
465
+ }
361
466
if (!equalEqual)
362
467
return ;
363
468
@@ -371,18 +476,46 @@ void swift::conformToCxxIteratorIfNeeded(
371
476
372
477
// Try to conform to UnsafeCxxRandomAccessIterator if possible.
373
478
374
- if (auto templateSpec =
375
- dyn_cast<clang::ClassTemplateSpecializationDecl>(clangDecl)) {
376
- instantiateTemplatedOperator (impl, templateSpec,
377
- clang::BinaryOperatorKind::BO_Sub);
479
+ // Check if present: `func -`
480
+ auto minus = getMinusOperator (decl);
481
+ if (!minus) {
482
+ clang::FunctionDecl *instantiated = instantiateTemplatedOperator (
483
+ impl, clangDecl, clang::BinaryOperatorKind::BO_Sub);
484
+ if (instantiated && !impl.isUnavailableInSwift (instantiated)) {
485
+ minus = getMinusOperator (decl);
486
+ if (!minus) {
487
+ clang::QualType returnTy = instantiated->getReturnType ();
488
+ auto paramTy = clangCtx.getRecordType (clangDecl);
489
+ synthesizeCXXOperator (impl, clangDecl,
490
+ clang::BinaryOperatorKind::BO_Sub, paramTy,
491
+ paramTy, returnTy);
492
+ minus = getMinusOperator (decl);
493
+ }
494
+ }
378
495
}
379
- auto minus = dyn_cast_or_null<FuncDecl>(getMinusOperator (decl));
380
496
if (!minus)
381
497
return ;
382
498
auto distanceTy = minus->getResultInterfaceType ();
383
499
// distanceTy conforms to BinaryInteger, this is ensured by getMinusOperator.
384
500
385
- auto plusEqual = dyn_cast_or_null<FuncDecl>(getPlusEqualOperator (decl, distanceTy));
501
+ auto plusEqual = getPlusEqualOperator (decl, distanceTy);
502
+ if (!plusEqual) {
503
+ clang::FunctionDecl *instantiated = instantiateTemplatedOperator (
504
+ impl, clangDecl, clang::BinaryOperatorKind::BO_AddAssign);
505
+ if (instantiated && !impl.isUnavailableInSwift (instantiated)) {
506
+ plusEqual = getPlusEqualOperator (decl, distanceTy);
507
+ if (!plusEqual) {
508
+ clang::QualType returnTy = instantiated->getReturnType ();
509
+ auto clangMinus = cast<clang::FunctionDecl>(minus->getClangDecl ());
510
+ auto lhsTy = clangCtx.getRecordType (clangDecl);
511
+ auto rhsTy = clangMinus->getReturnType ();
512
+ synthesizeCXXOperator (impl, clangDecl,
513
+ clang::BinaryOperatorKind::BO_AddAssign, lhsTy,
514
+ rhsTy, returnTy);
515
+ plusEqual = getPlusEqualOperator (decl, distanceTy);
516
+ }
517
+ }
518
+ }
386
519
if (!plusEqual)
387
520
return ;
388
521
0 commit comments