Skip to content

Commit 206a585

Browse files
committed
[cxx-interop] Handle inherited templated operators during auto-conformance
This fixes the automatic `std::unordered_map` conformance to CxxDictionary on Linux. Previously `std::unordered_map::const_iterator` was not auto-conformed to UnsafeCxxInputIterator because its `operator==` is defined on a templated base class of `const_iterator`. rdar://105220600
1 parent aad2b6e commit 206a585

10 files changed

+523
-25
lines changed

lib/ClangImporter/ClangDerivedConformances.cpp

Lines changed: 156 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
156156
return lookupOperator(decl, decl->getASTContext().Id_EqualsOperator, isValid);
157157
}
158158

159-
static ValueDecl *getMinusOperator(NominalTypeDecl *decl) {
159+
static FuncDecl *getMinusOperator(NominalTypeDecl *decl) {
160160
auto binaryIntegerProto =
161161
decl->getASTContext().getProtocol(KnownProtocolKind::BinaryInteger);
162162
auto module = decl->getModuleContext();
@@ -188,11 +188,12 @@ static ValueDecl *getMinusOperator(NominalTypeDecl *decl) {
188188
return true;
189189
};
190190

191-
return lookupOperator(decl, decl->getASTContext().getIdentifier("-"),
192-
isValid);
191+
ValueDecl *result =
192+
lookupOperator(decl, decl->getASTContext().getIdentifier("-"), isValid);
193+
return dyn_cast_or_null<FuncDecl>(result);
193194
}
194195

195-
static ValueDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
196+
static FuncDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
196197
auto isValid = [&](ValueDecl *plusEqualOp) -> bool {
197198
auto plusEqual = dyn_cast<FuncDecl>(plusEqualOp);
198199
if (!plusEqual || !plusEqual->hasParameterList())
@@ -219,14 +220,15 @@ static ValueDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
219220
return true;
220221
};
221222

222-
return lookupOperator(decl, decl->getASTContext().getIdentifier("+="),
223-
isValid);
223+
ValueDecl *result =
224+
lookupOperator(decl, decl->getASTContext().getIdentifier("+="), isValid);
225+
return dyn_cast_or_null<FuncDecl>(result);
224226
}
225227

226-
static void instantiateTemplatedOperator(
227-
ClangImporter::Implementation &impl,
228-
const clang::ClassTemplateSpecializationDecl *classDecl,
229-
clang::BinaryOperatorKind operatorKind) {
228+
static clang::FunctionDecl *
229+
instantiateTemplatedOperator(ClangImporter::Implementation &impl,
230+
const clang::CXXRecordDecl *classDecl,
231+
clang::BinaryOperatorKind operatorKind) {
230232

231233
clang::ASTContext &clangCtx = impl.getClangASTContext();
232234
clang::Sema &clangSema = impl.getClangSema();
@@ -252,6 +254,7 @@ static void instantiateTemplatedOperator(
252254
if (auto clangCallee = best->Function) {
253255
auto lookupTable = impl.findLookupTable(classDecl);
254256
addEntryToLookupTable(*lookupTable, clangCallee, impl.getNameImporter());
257+
return clangCallee;
255258
}
256259
break;
257260
}
@@ -260,6 +263,94 @@ static void instantiateTemplatedOperator(
260263
case clang::OR_Deleted:
261264
break;
262265
}
266+
267+
return nullptr;
268+
}
269+
270+
/// Warning: This function emits an error and stops compilation if the
271+
/// underlying operator function is unavailable in Swift for the current target
272+
/// (see `clang::Sema::DiagnoseAvailabilityOfDecl`).
273+
static bool synthesizeCXXOperator(ClangImporter::Implementation &impl,
274+
const clang::CXXRecordDecl *classDecl,
275+
clang::BinaryOperatorKind operatorKind,
276+
clang::QualType lhsTy, clang::QualType rhsTy,
277+
clang::QualType returnTy) {
278+
auto &clangCtx = impl.getClangASTContext();
279+
auto &clangSema = impl.getClangSema();
280+
281+
clang::OverloadedOperatorKind opKind =
282+
clang::BinaryOperator::getOverloadedOperator(operatorKind);
283+
const char *opSpelling = clang::getOperatorSpelling(opKind);
284+
285+
auto declName = clang::DeclarationName(&clangCtx.Idents.get(opSpelling));
286+
287+
// Determine the Clang decl context where the new operator function will be
288+
// created. We use the translation unit as the decl context of the new
289+
// operator, otherwise, the operator might get imported as a static member
290+
// function of a different type (e.g. an operator declared inside of a C++
291+
// namespace would get imported as a member function of a Swift enum), which
292+
// would make the operator un-discoverable to Swift name lookup.
293+
auto declContext =
294+
const_cast<clang::CXXRecordDecl *>(classDecl)->getDeclContext();
295+
while (!declContext->isTranslationUnit()) {
296+
declContext = declContext->getParent();
297+
}
298+
299+
auto equalEqualTy = clangCtx.getFunctionType(
300+
returnTy, {lhsTy, rhsTy}, clang::FunctionProtoType::ExtProtoInfo());
301+
302+
// Create a `bool operator==(T, T)` function.
303+
auto equalEqualDecl = clang::FunctionDecl::Create(
304+
clangCtx, declContext, clang::SourceLocation(), clang::SourceLocation(),
305+
declName, equalEqualTy, clangCtx.getTrivialTypeSourceInfo(returnTy),
306+
clang::StorageClass::SC_Static);
307+
equalEqualDecl->setImplicit();
308+
equalEqualDecl->setImplicitlyInline();
309+
// If this is a static member function of a class, it needs to be public.
310+
equalEqualDecl->setAccess(clang::AccessSpecifier::AS_public);
311+
312+
// Create the parameters of the function. They are not referenced from source
313+
// code, so they don't need to have a name.
314+
auto lhsParamId = nullptr;
315+
auto lhsTyInfo = clangCtx.getTrivialTypeSourceInfo(lhsTy);
316+
auto lhsParamDecl = clang::ParmVarDecl::Create(
317+
clangCtx, equalEqualDecl, clang::SourceLocation(),
318+
clang::SourceLocation(), lhsParamId, lhsTy, lhsTyInfo,
319+
clang::StorageClass::SC_None, /*DefArg*/ nullptr);
320+
auto lhsParamRefExpr = new (clangCtx) clang::DeclRefExpr(
321+
clangCtx, lhsParamDecl, false, lhsTy, clang::ExprValueKind::VK_LValue,
322+
clang::SourceLocation());
323+
324+
auto rhsParamId = nullptr;
325+
auto rhsTyInfo = clangCtx.getTrivialTypeSourceInfo(rhsTy);
326+
auto rhsParamDecl = clang::ParmVarDecl::Create(
327+
clangCtx, equalEqualDecl, clang::SourceLocation(),
328+
clang::SourceLocation(), rhsParamId, rhsTy, rhsTyInfo,
329+
clang::StorageClass::SC_None, nullptr);
330+
auto rhsParamRefExpr = new (clangCtx) clang::DeclRefExpr(
331+
clangCtx, rhsParamDecl, false, rhsTy, clang::ExprValueKind::VK_LValue,
332+
clang::SourceLocation());
333+
334+
equalEqualDecl->setParams({lhsParamDecl, rhsParamDecl});
335+
336+
// Lookup the `operator==` function that will be called under the hood.
337+
clang::UnresolvedSet<16> operators;
338+
// Note: calling `CreateOverloadedBinOp` emits an error if the looked up
339+
// function is unavailable for the current target.
340+
auto underlyingCallResult = clangSema.CreateOverloadedBinOp(
341+
clang::SourceLocation(), operatorKind, operators, lhsParamRefExpr,
342+
rhsParamRefExpr);
343+
if (!underlyingCallResult.isUsable())
344+
return false;
345+
auto underlyingCall = underlyingCallResult.get();
346+
347+
auto equalEqualBody = clang::ReturnStmt::Create(
348+
clangCtx, clang::SourceLocation(), underlyingCall, nullptr);
349+
equalEqualDecl->setBody(equalEqualBody);
350+
351+
auto lookupTable = impl.findLookupTable(classDecl);
352+
addEntryToLookupTable(*lookupTable, equalEqualDecl, impl.getNameImporter());
353+
return true;
263354
}
264355

265356
bool swift::isIterator(const clang::CXXRecordDecl *clangDecl) {
@@ -274,6 +365,7 @@ void swift::conformToCxxIteratorIfNeeded(
274365
assert(decl);
275366
assert(clangDecl);
276367
ASTContext &ctx = decl->getASTContext();
368+
clang::ASTContext &clangCtx = clangDecl->getASTContext();
277369

278370
if (!ctx.getProtocol(KnownProtocolKind::UnsafeCxxInputIterator))
279371
return;
@@ -349,15 +441,28 @@ void swift::conformToCxxIteratorIfNeeded(
349441
if (!successorTy || successorTy->getAnyNominal() != decl)
350442
return;
351443

352-
// If this is a templated class, `operator==` might be templated as well.
353-
// Try to instantiate it.
354-
if (auto templateSpec =
355-
dyn_cast<clang::ClassTemplateSpecializationDecl>(clangDecl)) {
356-
instantiateTemplatedOperator(impl, templateSpec,
357-
clang::BinaryOperatorKind::BO_EQ);
358-
}
359444
// Check if present: `func ==`
360445
auto equalEqual = getEqualEqualOperator(decl);
446+
if (!equalEqual) {
447+
// If this class is inherited, `operator==` might be defined for a base
448+
// class. If this is a templated class, `operator==` might be templated as
449+
// well. Try to instantiate it.
450+
clang::FunctionDecl *instantiated = instantiateTemplatedOperator(
451+
impl, clangDecl, clang::BinaryOperatorKind::BO_EQ);
452+
if (instantiated && !impl.isUnavailableInSwift(instantiated)) {
453+
// If `operator==` was instantiated successfully, try to find `func ==`
454+
// again.
455+
equalEqual = getEqualEqualOperator(decl);
456+
if (!equalEqual) {
457+
// If `func ==` still can't be found, it might be defined for a base
458+
// class of the current class.
459+
auto paramTy = clangCtx.getRecordType(clangDecl);
460+
synthesizeCXXOperator(impl, clangDecl, clang::BinaryOperatorKind::BO_EQ,
461+
paramTy, paramTy, clangCtx.BoolTy);
462+
equalEqual = getEqualEqualOperator(decl);
463+
}
464+
}
465+
}
361466
if (!equalEqual)
362467
return;
363468

@@ -371,18 +476,46 @@ void swift::conformToCxxIteratorIfNeeded(
371476

372477
// Try to conform to UnsafeCxxRandomAccessIterator if possible.
373478

374-
if (auto templateSpec =
375-
dyn_cast<clang::ClassTemplateSpecializationDecl>(clangDecl)) {
376-
instantiateTemplatedOperator(impl, templateSpec,
377-
clang::BinaryOperatorKind::BO_Sub);
479+
// Check if present: `func -`
480+
auto minus = getMinusOperator(decl);
481+
if (!minus) {
482+
clang::FunctionDecl *instantiated = instantiateTemplatedOperator(
483+
impl, clangDecl, clang::BinaryOperatorKind::BO_Sub);
484+
if (instantiated && !impl.isUnavailableInSwift(instantiated)) {
485+
minus = getMinusOperator(decl);
486+
if (!minus) {
487+
clang::QualType returnTy = instantiated->getReturnType();
488+
auto paramTy = clangCtx.getRecordType(clangDecl);
489+
synthesizeCXXOperator(impl, clangDecl,
490+
clang::BinaryOperatorKind::BO_Sub, paramTy,
491+
paramTy, returnTy);
492+
minus = getMinusOperator(decl);
493+
}
494+
}
378495
}
379-
auto minus = dyn_cast_or_null<FuncDecl>(getMinusOperator(decl));
380496
if (!minus)
381497
return;
382498
auto distanceTy = minus->getResultInterfaceType();
383499
// distanceTy conforms to BinaryInteger, this is ensured by getMinusOperator.
384500

385-
auto plusEqual = dyn_cast_or_null<FuncDecl>(getPlusEqualOperator(decl, distanceTy));
501+
auto plusEqual = getPlusEqualOperator(decl, distanceTy);
502+
if (!plusEqual) {
503+
clang::FunctionDecl *instantiated = instantiateTemplatedOperator(
504+
impl, clangDecl, clang::BinaryOperatorKind::BO_AddAssign);
505+
if (instantiated && !impl.isUnavailableInSwift(instantiated)) {
506+
plusEqual = getPlusEqualOperator(decl, distanceTy);
507+
if (!plusEqual) {
508+
clang::QualType returnTy = instantiated->getReturnType();
509+
auto clangMinus = cast<clang::FunctionDecl>(minus->getClangDecl());
510+
auto lhsTy = clangCtx.getRecordType(clangDecl);
511+
auto rhsTy = clangMinus->getReturnType();
512+
synthesizeCXXOperator(impl, clangDecl,
513+
clang::BinaryOperatorKind::BO_AddAssign, lhsTy,
514+
rhsTy, returnTy);
515+
plusEqual = getPlusEqualOperator(decl, distanceTy);
516+
}
517+
}
518+
}
386519
if (!plusEqual)
387520
return;
388521

lib/ClangImporter/ClangImporter.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2842,6 +2842,13 @@ static bool isVisibleFromModule(const ClangModuleUnit *ModuleFilter,
28422842
if (OwningClangModule == ModuleFilter->getClangModule())
28432843
return true;
28442844

2845+
// If this decl was implicitly synthesized by the compiler, and is not
2846+
// supposed to be owned by any module, return true.
2847+
if (D->isImplicit() && D->getModuleOwnershipKind() ==
2848+
clang::Decl::ModuleOwnershipKind::Unowned) {
2849+
return true;
2850+
}
2851+
28452852
// Friends from class templates don't have an owning module. Just return true.
28462853
if (isa<clang::FunctionDecl>(D) &&
28472854
cast<clang::FunctionDecl>(D)->isThisDeclarationInstantiatedFromAFriendDefinition())

test/Interop/Cxx/stdlib/overlay/Inputs/custom-collection.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,38 @@ struct SimpleCollectionReadOnly {
2828
const int& operator[](int index) const { return x[index]; }
2929
};
3030

31+
template <typename T>
32+
struct HasInheritedTemplatedConstRACIterator {
33+
public:
34+
typedef InheritedTemplatedConstRACIterator<int> iterator;
35+
36+
private:
37+
iterator b = iterator(1);
38+
iterator e = iterator(6);
39+
40+
public:
41+
iterator begin() const { return b; }
42+
iterator end() const { return e; }
43+
};
44+
45+
typedef HasInheritedTemplatedConstRACIterator<int>
46+
HasInheritedTemplatedConstRACIteratorInt;
47+
48+
template <typename T>
49+
struct HasInheritedTemplatedConstRACIteratorOutOfLineOps {
50+
public:
51+
typedef InheritedTemplatedConstRACIteratorOutOfLineOps<int> iterator;
52+
53+
private:
54+
iterator b = iterator(1);
55+
iterator e = iterator(4);
56+
57+
public:
58+
iterator begin() const { return b; }
59+
iterator end() const { return e; }
60+
};
61+
62+
typedef HasInheritedTemplatedConstRACIteratorOutOfLineOps<int>
63+
HasInheritedTemplatedConstRACIteratorOutOfLineOpsInt;
64+
3165
#endif // TEST_INTEROP_CXX_STDLIB_INPUTS_CUSTOM_COLLECTION_H

0 commit comments

Comments
 (0)