Skip to content

Commit f450eb5

Browse files
authored
Merge pull request #67287 from apple/egorzhdan/cxx-equal-equal
[cxx-interop] Handle inherited templated operators during auto-conformance
2 parents 7d37cdf + bc56ddc commit f450eb5

12 files changed

+532
-25
lines changed

include/swift/ClangImporter/ClangImporter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,8 @@ class ClangImporter final : public ClangModuleLoader {
568568
clang::FunctionTemplateDecl *func,
569569
SubstitutionMap subst) override;
570570

571+
bool isSynthesizedAndVisibleFromAllModules(const clang::Decl *decl);
572+
571573
bool isCXXMethodMutating(const clang::CXXMethodDecl *method) override;
572574

573575
bool isUnsafeCXXMethod(const FuncDecl *func) override;

lib/ClangImporter/ClangDerivedConformances.cpp

Lines changed: 157 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
156156
return lookupOperator(decl, decl->getASTContext().Id_EqualsOperator, isValid);
157157
}
158158

159-
static ValueDecl *getMinusOperator(NominalTypeDecl *decl) {
159+
static FuncDecl *getMinusOperator(NominalTypeDecl *decl) {
160160
auto binaryIntegerProto =
161161
decl->getASTContext().getProtocol(KnownProtocolKind::BinaryInteger);
162162
auto module = decl->getModuleContext();
@@ -188,11 +188,12 @@ static ValueDecl *getMinusOperator(NominalTypeDecl *decl) {
188188
return true;
189189
};
190190

191-
return lookupOperator(decl, decl->getASTContext().getIdentifier("-"),
192-
isValid);
191+
ValueDecl *result =
192+
lookupOperator(decl, decl->getASTContext().getIdentifier("-"), isValid);
193+
return dyn_cast_or_null<FuncDecl>(result);
193194
}
194195

195-
static ValueDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
196+
static FuncDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
196197
auto isValid = [&](ValueDecl *plusEqualOp) -> bool {
197198
auto plusEqual = dyn_cast<FuncDecl>(plusEqualOp);
198199
if (!plusEqual || !plusEqual->hasParameterList())
@@ -219,14 +220,15 @@ static ValueDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
219220
return true;
220221
};
221222

222-
return lookupOperator(decl, decl->getASTContext().getIdentifier("+="),
223-
isValid);
223+
ValueDecl *result =
224+
lookupOperator(decl, decl->getASTContext().getIdentifier("+="), isValid);
225+
return dyn_cast_or_null<FuncDecl>(result);
224226
}
225227

226-
static void instantiateTemplatedOperator(
227-
ClangImporter::Implementation &impl,
228-
const clang::ClassTemplateSpecializationDecl *classDecl,
229-
clang::BinaryOperatorKind operatorKind) {
228+
static clang::FunctionDecl *
229+
instantiateTemplatedOperator(ClangImporter::Implementation &impl,
230+
const clang::CXXRecordDecl *classDecl,
231+
clang::BinaryOperatorKind operatorKind) {
230232

231233
clang::ASTContext &clangCtx = impl.getClangASTContext();
232234
clang::Sema &clangSema = impl.getClangSema();
@@ -252,6 +254,7 @@ static void instantiateTemplatedOperator(
252254
if (auto clangCallee = best->Function) {
253255
auto lookupTable = impl.findLookupTable(classDecl);
254256
addEntryToLookupTable(*lookupTable, clangCallee, impl.getNameImporter());
257+
return clangCallee;
255258
}
256259
break;
257260
}
@@ -260,6 +263,95 @@ static void instantiateTemplatedOperator(
260263
case clang::OR_Deleted:
261264
break;
262265
}
266+
267+
return nullptr;
268+
}
269+
270+
/// Warning: This function emits an error and stops compilation if the
271+
/// underlying operator function is unavailable in Swift for the current target
272+
/// (see `clang::Sema::DiagnoseAvailabilityOfDecl`).
273+
static bool synthesizeCXXOperator(ClangImporter::Implementation &impl,
274+
const clang::CXXRecordDecl *classDecl,
275+
clang::BinaryOperatorKind operatorKind,
276+
clang::QualType lhsTy, clang::QualType rhsTy,
277+
clang::QualType returnTy) {
278+
auto &clangCtx = impl.getClangASTContext();
279+
auto &clangSema = impl.getClangSema();
280+
281+
clang::OverloadedOperatorKind opKind =
282+
clang::BinaryOperator::getOverloadedOperator(operatorKind);
283+
const char *opSpelling = clang::getOperatorSpelling(opKind);
284+
285+
auto declName = clang::DeclarationName(&clangCtx.Idents.get(opSpelling));
286+
287+
// Determine the Clang decl context where the new operator function will be
288+
// created. We use the translation unit as the decl context of the new
289+
// operator, otherwise, the operator might get imported as a static member
290+
// function of a different type (e.g. an operator declared inside of a C++
291+
// namespace would get imported as a member function of a Swift enum), which
292+
// would make the operator un-discoverable to Swift name lookup.
293+
auto declContext =
294+
const_cast<clang::CXXRecordDecl *>(classDecl)->getDeclContext();
295+
while (!declContext->isTranslationUnit()) {
296+
declContext = declContext->getParent();
297+
}
298+
299+
auto equalEqualTy = clangCtx.getFunctionType(
300+
returnTy, {lhsTy, rhsTy}, clang::FunctionProtoType::ExtProtoInfo());
301+
302+
// Create a `bool operator==(T, T)` function.
303+
auto equalEqualDecl = clang::FunctionDecl::Create(
304+
clangCtx, declContext, clang::SourceLocation(), clang::SourceLocation(),
305+
declName, equalEqualTy, clangCtx.getTrivialTypeSourceInfo(returnTy),
306+
clang::StorageClass::SC_Static);
307+
equalEqualDecl->setImplicit();
308+
equalEqualDecl->setImplicitlyInline();
309+
// If this is a static member function of a class, it needs to be public.
310+
equalEqualDecl->setAccess(clang::AccessSpecifier::AS_public);
311+
312+
// Create the parameters of the function. They are not referenced from source
313+
// code, so they don't need to have a name.
314+
auto lhsParamId = nullptr;
315+
auto lhsTyInfo = clangCtx.getTrivialTypeSourceInfo(lhsTy);
316+
auto lhsParamDecl = clang::ParmVarDecl::Create(
317+
clangCtx, equalEqualDecl, clang::SourceLocation(),
318+
clang::SourceLocation(), lhsParamId, lhsTy, lhsTyInfo,
319+
clang::StorageClass::SC_None, /*DefArg*/ nullptr);
320+
auto lhsParamRefExpr = new (clangCtx) clang::DeclRefExpr(
321+
clangCtx, lhsParamDecl, false, lhsTy, clang::ExprValueKind::VK_LValue,
322+
clang::SourceLocation());
323+
324+
auto rhsParamId = nullptr;
325+
auto rhsTyInfo = clangCtx.getTrivialTypeSourceInfo(rhsTy);
326+
auto rhsParamDecl = clang::ParmVarDecl::Create(
327+
clangCtx, equalEqualDecl, clang::SourceLocation(),
328+
clang::SourceLocation(), rhsParamId, rhsTy, rhsTyInfo,
329+
clang::StorageClass::SC_None, nullptr);
330+
auto rhsParamRefExpr = new (clangCtx) clang::DeclRefExpr(
331+
clangCtx, rhsParamDecl, false, rhsTy, clang::ExprValueKind::VK_LValue,
332+
clang::SourceLocation());
333+
334+
equalEqualDecl->setParams({lhsParamDecl, rhsParamDecl});
335+
336+
// Lookup the `operator==` function that will be called under the hood.
337+
clang::UnresolvedSet<16> operators;
338+
// Note: calling `CreateOverloadedBinOp` emits an error if the looked up
339+
// function is unavailable for the current target.
340+
auto underlyingCallResult = clangSema.CreateOverloadedBinOp(
341+
clang::SourceLocation(), operatorKind, operators, lhsParamRefExpr,
342+
rhsParamRefExpr);
343+
if (!underlyingCallResult.isUsable())
344+
return false;
345+
auto underlyingCall = underlyingCallResult.get();
346+
347+
auto equalEqualBody = clang::ReturnStmt::Create(
348+
clangCtx, clang::SourceLocation(), underlyingCall, nullptr);
349+
equalEqualDecl->setBody(equalEqualBody);
350+
351+
impl.synthesizedAndAlwaysVisibleDecls.insert(equalEqualDecl);
352+
auto lookupTable = impl.findLookupTable(classDecl);
353+
addEntryToLookupTable(*lookupTable, equalEqualDecl, impl.getNameImporter());
354+
return true;
263355
}
264356

265357
bool swift::isIterator(const clang::CXXRecordDecl *clangDecl) {
@@ -274,6 +366,7 @@ void swift::conformToCxxIteratorIfNeeded(
274366
assert(decl);
275367
assert(clangDecl);
276368
ASTContext &ctx = decl->getASTContext();
369+
clang::ASTContext &clangCtx = clangDecl->getASTContext();
277370

278371
if (!ctx.getProtocol(KnownProtocolKind::UnsafeCxxInputIterator))
279372
return;
@@ -349,15 +442,28 @@ void swift::conformToCxxIteratorIfNeeded(
349442
if (!successorTy || successorTy->getAnyNominal() != decl)
350443
return;
351444

352-
// If this is a templated class, `operator==` might be templated as well.
353-
// Try to instantiate it.
354-
if (auto templateSpec =
355-
dyn_cast<clang::ClassTemplateSpecializationDecl>(clangDecl)) {
356-
instantiateTemplatedOperator(impl, templateSpec,
357-
clang::BinaryOperatorKind::BO_EQ);
358-
}
359445
// Check if present: `func ==`
360446
auto equalEqual = getEqualEqualOperator(decl);
447+
if (!equalEqual) {
448+
// If this class is inherited, `operator==` might be defined for a base
449+
// class. If this is a templated class, `operator==` might be templated as
450+
// well. Try to instantiate it.
451+
clang::FunctionDecl *instantiated = instantiateTemplatedOperator(
452+
impl, clangDecl, clang::BinaryOperatorKind::BO_EQ);
453+
if (instantiated && !impl.isUnavailableInSwift(instantiated)) {
454+
// If `operator==` was instantiated successfully, try to find `func ==`
455+
// again.
456+
equalEqual = getEqualEqualOperator(decl);
457+
if (!equalEqual) {
458+
// If `func ==` still can't be found, it might be defined for a base
459+
// class of the current class.
460+
auto paramTy = clangCtx.getRecordType(clangDecl);
461+
synthesizeCXXOperator(impl, clangDecl, clang::BinaryOperatorKind::BO_EQ,
462+
paramTy, paramTy, clangCtx.BoolTy);
463+
equalEqual = getEqualEqualOperator(decl);
464+
}
465+
}
466+
}
361467
if (!equalEqual)
362468
return;
363469

@@ -371,18 +477,46 @@ void swift::conformToCxxIteratorIfNeeded(
371477

372478
// Try to conform to UnsafeCxxRandomAccessIterator if possible.
373479

374-
if (auto templateSpec =
375-
dyn_cast<clang::ClassTemplateSpecializationDecl>(clangDecl)) {
376-
instantiateTemplatedOperator(impl, templateSpec,
377-
clang::BinaryOperatorKind::BO_Sub);
480+
// Check if present: `func -`
481+
auto minus = getMinusOperator(decl);
482+
if (!minus) {
483+
clang::FunctionDecl *instantiated = instantiateTemplatedOperator(
484+
impl, clangDecl, clang::BinaryOperatorKind::BO_Sub);
485+
if (instantiated && !impl.isUnavailableInSwift(instantiated)) {
486+
minus = getMinusOperator(decl);
487+
if (!minus) {
488+
clang::QualType returnTy = instantiated->getReturnType();
489+
auto paramTy = clangCtx.getRecordType(clangDecl);
490+
synthesizeCXXOperator(impl, clangDecl,
491+
clang::BinaryOperatorKind::BO_Sub, paramTy,
492+
paramTy, returnTy);
493+
minus = getMinusOperator(decl);
494+
}
495+
}
378496
}
379-
auto minus = dyn_cast_or_null<FuncDecl>(getMinusOperator(decl));
380497
if (!minus)
381498
return;
382499
auto distanceTy = minus->getResultInterfaceType();
383500
// distanceTy conforms to BinaryInteger, this is ensured by getMinusOperator.
384501

385-
auto plusEqual = dyn_cast_or_null<FuncDecl>(getPlusEqualOperator(decl, distanceTy));
502+
auto plusEqual = getPlusEqualOperator(decl, distanceTy);
503+
if (!plusEqual) {
504+
clang::FunctionDecl *instantiated = instantiateTemplatedOperator(
505+
impl, clangDecl, clang::BinaryOperatorKind::BO_AddAssign);
506+
if (instantiated && !impl.isUnavailableInSwift(instantiated)) {
507+
plusEqual = getPlusEqualOperator(decl, distanceTy);
508+
if (!plusEqual) {
509+
clang::QualType returnTy = instantiated->getReturnType();
510+
auto clangMinus = cast<clang::FunctionDecl>(minus->getClangDecl());
511+
auto lhsTy = clangCtx.getRecordType(clangDecl);
512+
auto rhsTy = clangMinus->getReturnType();
513+
synthesizeCXXOperator(impl, clangDecl,
514+
clang::BinaryOperatorKind::BO_AddAssign, lhsTy,
515+
rhsTy, returnTy);
516+
plusEqual = getPlusEqualOperator(decl, distanceTy);
517+
}
518+
}
519+
}
386520
if (!plusEqual)
387521
return;
388522

lib/ClangImporter/ClangImporter.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2842,6 +2842,12 @@ static bool isVisibleFromModule(const ClangModuleUnit *ModuleFilter,
28422842
if (OwningClangModule == ModuleFilter->getClangModule())
28432843
return true;
28442844

2845+
// If this decl was implicitly synthesized by the compiler, and is not
2846+
// supposed to be owned by any module, return true.
2847+
if (Importer->isSynthesizedAndVisibleFromAllModules(D)) {
2848+
return true;
2849+
}
2850+
28452851
// Friends from class templates don't have an owning module. Just return true.
28462852
if (isa<clang::FunctionDecl>(D) &&
28472853
cast<clang::FunctionDecl>(D)->isThisDeclarationInstantiatedFromAFriendDefinition())
@@ -6302,6 +6308,11 @@ FuncDecl *ClangImporter::getCXXSynthesizedOperatorFunc(FuncDecl *decl) {
63026308
return cast<FuncDecl>(synthesizedOperator);
63036309
}
63046310

6311+
bool ClangImporter::isSynthesizedAndVisibleFromAllModules(
6312+
const clang::Decl *decl) {
6313+
return Impl.synthesizedAndAlwaysVisibleDecls.contains(decl);
6314+
}
6315+
63056316
bool ClangImporter::isCXXMethodMutating(const clang::CXXMethodDecl *method) {
63066317
if (isa<clang::CXXConstructorDecl>(method) || !method->isConst())
63076318
return true;

lib/ClangImporter/ImporterImpl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,8 @@ class LLVM_LIBRARY_VISIBILITY ClangImporter::Implementation
642642
llvm::MapVector<std::pair<NominalTypeDecl *, Type>,
643643
std::pair<FuncDecl *, FuncDecl *>> cxxSubscripts;
644644

645+
llvm::SmallPtrSet<const clang::Decl *, 1> synthesizedAndAlwaysVisibleDecls;
646+
645647
private:
646648
// Keep track of the decls that were already cloned for this specific class.
647649
llvm::DenseMap<std::pair<ValueDecl *, DeclContext *>, ValueDecl *>

test/Interop/Cxx/stdlib/overlay/Inputs/custom-collection.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,38 @@ struct SimpleCollectionReadOnly {
2828
const int& operator[](int index) const { return x[index]; }
2929
};
3030

31+
template <typename T>
32+
struct HasInheritedTemplatedConstRACIterator {
33+
public:
34+
typedef InheritedTemplatedConstRACIterator<int> iterator;
35+
36+
private:
37+
iterator b = iterator(1);
38+
iterator e = iterator(6);
39+
40+
public:
41+
iterator begin() const { return b; }
42+
iterator end() const { return e; }
43+
};
44+
45+
typedef HasInheritedTemplatedConstRACIterator<int>
46+
HasInheritedTemplatedConstRACIteratorInt;
47+
48+
template <typename T>
49+
struct HasInheritedTemplatedConstRACIteratorOutOfLineOps {
50+
public:
51+
typedef InheritedTemplatedConstRACIteratorOutOfLineOps<int> iterator;
52+
53+
private:
54+
iterator b = iterator(1);
55+
iterator e = iterator(4);
56+
57+
public:
58+
iterator begin() const { return b; }
59+
iterator end() const { return e; }
60+
};
61+
62+
typedef HasInheritedTemplatedConstRACIteratorOutOfLineOps<int>
63+
HasInheritedTemplatedConstRACIteratorOutOfLineOpsInt;
64+
3165
#endif // TEST_INTEROP_CXX_STDLIB_INPUTS_CUSTOM_COLLECTION_H

0 commit comments

Comments
 (0)