Skip to content

Commit 59d7555

Browse files
committed
[cxx-interop] Handle inherited templated operators during auto-conformance
This fixes the automatic `std::unordered_map` conformance to CxxDictionary on Linux. Previously `std::unordered_map::const_iterator` was not auto-conformed to UnsafeCxxInputIterator because its `operator==` is defined on a templated base class of `const_iterator`. rdar://105220600
1 parent ddeab73 commit 59d7555

9 files changed

+441
-19
lines changed

lib/ClangImporter/ClangDerivedConformances.cpp

Lines changed: 115 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
156156
return lookupOperator(decl, decl->getASTContext().Id_EqualsOperator, isValid);
157157
}
158158

159-
static ValueDecl *getMinusOperator(NominalTypeDecl *decl) {
159+
static FuncDecl *getMinusOperator(NominalTypeDecl *decl) {
160160
auto binaryIntegerProto =
161161
decl->getASTContext().getProtocol(KnownProtocolKind::BinaryInteger);
162162
auto module = decl->getModuleContext();
@@ -188,8 +188,9 @@ static ValueDecl *getMinusOperator(NominalTypeDecl *decl) {
188188
return true;
189189
};
190190

191-
return lookupOperator(decl, decl->getASTContext().getIdentifier("-"),
192-
isValid);
191+
ValueDecl *result =
192+
lookupOperator(decl, decl->getASTContext().getIdentifier("-"), isValid);
193+
return dyn_cast_or_null<FuncDecl>(result);
193194
}
194195

195196
static ValueDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
@@ -223,10 +224,10 @@ static ValueDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
223224
isValid);
224225
}
225226

226-
static void instantiateTemplatedOperator(
227-
ClangImporter::Implementation &impl,
228-
const clang::ClassTemplateSpecializationDecl *classDecl,
229-
clang::BinaryOperatorKind operatorKind) {
227+
static clang::FunctionDecl *
228+
instantiateTemplatedOperator(ClangImporter::Implementation &impl,
229+
const clang::CXXRecordDecl *classDecl,
230+
clang::BinaryOperatorKind operatorKind) {
230231

231232
clang::ASTContext &clangCtx = impl.getClangASTContext();
232233
clang::Sema &clangSema = impl.getClangSema();
@@ -252,6 +253,7 @@ static void instantiateTemplatedOperator(
252253
if (auto clangCallee = best->Function) {
253254
auto lookupTable = impl.findLookupTable(classDecl);
254255
addEntryToLookupTable(*lookupTable, clangCallee, impl.getNameImporter());
256+
return clangCallee;
255257
}
256258
break;
257259
}
@@ -260,6 +262,82 @@ static void instantiateTemplatedOperator(
260262
case clang::OR_Deleted:
261263
break;
262264
}
265+
266+
return nullptr;
267+
}
268+
269+
/// Warning: This function emits an error and stops compilation if the
270+
/// underlying operator function is unavailable in Swift for the current target
271+
/// (see `clang::Sema::DiagnoseAvailabilityOfDecl`).
272+
static bool makeOperatorFunc(ClangImporter::Implementation &impl,
273+
const clang::CXXRecordDecl *classDecl,
274+
clang::BinaryOperatorKind operatorKind) {
275+
auto &clangCtx = impl.getClangASTContext();
276+
auto &clangSema = impl.getClangSema();
277+
auto classTy = clangCtx.getRecordType(classDecl);
278+
auto classTyInfo = clangCtx.getTrivialTypeSourceInfo(classTy);
279+
280+
clang::OverloadedOperatorKind opKind =
281+
clang::BinaryOperator::getOverloadedOperator(operatorKind);
282+
const char *opSpelling = clang::getOperatorSpelling(opKind);
283+
284+
auto declName = clang::DeclarationName(&clangCtx.Idents.get(opSpelling));
285+
auto declContext =
286+
const_cast<clang::CXXRecordDecl *>(classDecl)->getDeclContext();
287+
auto equalEqualTy =
288+
clangCtx.getFunctionType(clangCtx.BoolTy, {classTy, classTy},
289+
clang::FunctionProtoType::ExtProtoInfo());
290+
291+
// Create a `bool operator==(T, T)` function.
292+
auto equalEqualDecl = clang::FunctionDecl::Create(
293+
clangCtx, declContext, clang::SourceLocation(), clang::SourceLocation(),
294+
declName, equalEqualTy,
295+
clangCtx.getTrivialTypeSourceInfo(clangCtx.BoolTy),
296+
clang::StorageClass::SC_Static);
297+
equalEqualDecl->setImplicit();
298+
equalEqualDecl->setImplicitlyInline();
299+
equalEqualDecl->setAccess(clang::AccessSpecifier::AS_public);
300+
301+
// Create the parameters of the function. They are not referenced from source
302+
// code, so they don't need to have a name.
303+
auto lhsParamId = nullptr;
304+
auto lhsParamDecl = clang::ParmVarDecl::Create(
305+
clangCtx, equalEqualDecl, clang::SourceLocation(),
306+
clang::SourceLocation(), lhsParamId, classTy, classTyInfo,
307+
clang::StorageClass::SC_None, /*DefArg*/ nullptr);
308+
auto lhsParamRefExpr = new (clangCtx) clang::DeclRefExpr(
309+
clangCtx, lhsParamDecl, false, classTy, clang::ExprValueKind::VK_LValue,
310+
clang::SourceLocation());
311+
312+
auto rhsParamId = nullptr;
313+
auto rhsParamDecl = clang::ParmVarDecl::Create(
314+
clangCtx, equalEqualDecl, clang::SourceLocation(),
315+
clang::SourceLocation(), rhsParamId, classTy, classTyInfo,
316+
clang::StorageClass::SC_None, nullptr);
317+
auto rhsParamRefExpr = new (clangCtx) clang::DeclRefExpr(
318+
clangCtx, rhsParamDecl, false, classTy, clang::ExprValueKind::VK_LValue,
319+
clang::SourceLocation());
320+
321+
equalEqualDecl->setParams({lhsParamDecl, rhsParamDecl});
322+
323+
// Lookup the `operator==` function that will be called under the hood.
324+
clang::UnresolvedSet<16> operators;
325+
// Note: calling `CreateOverloadedBinOp` emits an error if the looked up
326+
// function is unavailable for the current target.
327+
auto underlyingCallResult = clangSema.CreateOverloadedBinOp(
328+
clang::SourceLocation(), operatorKind, operators, lhsParamRefExpr,
329+
rhsParamRefExpr);
330+
if (!underlyingCallResult.isUsable())
331+
return false;
332+
auto underlyingCall = underlyingCallResult.get();
333+
334+
auto equalEqualBody = clang::ReturnStmt::Create(
335+
clangCtx, clang::SourceLocation(), underlyingCall, nullptr);
336+
equalEqualDecl->setBody(equalEqualBody);
337+
338+
auto lookupTable = impl.findLookupTable(classDecl);
339+
addEntryToLookupTable(*lookupTable, equalEqualDecl, impl.getNameImporter());
340+
return true;
263341
}
264342

265343
bool swift::isIterator(const clang::CXXRecordDecl *clangDecl) {
@@ -349,15 +427,26 @@ void swift::conformToCxxIteratorIfNeeded(
349427
if (!successorTy || successorTy->getAnyNominal() != decl)
350428
return;
351429

352-
// If this is a templated class, `operator==` might be templated as well.
353-
// Try to instantiate it.
354-
if (auto templateSpec =
355-
dyn_cast<clang::ClassTemplateSpecializationDecl>(clangDecl)) {
356-
instantiateTemplatedOperator(impl, templateSpec,
357-
clang::BinaryOperatorKind::BO_EQ);
358-
}
359430
// Check if present: `func ==`
360431
auto equalEqual = getEqualEqualOperator(decl);
432+
if (!equalEqual) {
433+
// If this class is inherited, `operator==` might be defined for a base
434+
// class. If this is a templated class, `operator==` might be templated as
435+
// well. Try to instantiate it.
436+
clang::FunctionDecl* instantiated = instantiateTemplatedOperator(
437+
impl, clangDecl, clang::BinaryOperatorKind::BO_EQ);
438+
if (instantiated && !impl.isUnavailableInSwift(instantiated)) {
439+
// If `operator==` was instantiated successfully, try to find `func ==`
440+
// again.
441+
equalEqual = getEqualEqualOperator(decl);
442+
if (!equalEqual) {
443+
// If `func ==` still can't be found, it might be defined for a base
444+
// class of the current class.
445+
makeOperatorFunc(impl, clangDecl, clang::BinaryOperatorKind::BO_EQ);
446+
equalEqual = getEqualEqualOperator(decl);
447+
}
448+
}
449+
}
361450
if (!equalEqual)
362451
return;
363452

@@ -371,12 +460,19 @@ void swift::conformToCxxIteratorIfNeeded(
371460

372461
// Try to conform to UnsafeCxxRandomAccessIterator if possible.
373462

374-
if (auto templateSpec =
375-
dyn_cast<clang::ClassTemplateSpecializationDecl>(clangDecl)) {
376-
instantiateTemplatedOperator(impl, templateSpec,
377-
clang::BinaryOperatorKind::BO_Sub);
463+
// Check if present: `func -`
464+
auto minus = getMinusOperator(decl);
465+
if (!minus) {
466+
clang::FunctionDecl *instantiated = instantiateTemplatedOperator(
467+
impl, clangDecl, clang::BinaryOperatorKind::BO_Sub);
468+
if (instantiated && !impl.isUnavailableInSwift(instantiated)) {
469+
minus = getMinusOperator(decl);
470+
if (!minus) {
471+
makeOperatorFunc(impl, clangDecl, clang::BinaryOperatorKind::BO_Sub);
472+
minus = getMinusOperator(decl);
473+
}
474+
}
378475
}
379-
auto minus = dyn_cast_or_null<FuncDecl>(getMinusOperator(decl));
380476
if (!minus)
381477
return;
382478
auto distanceTy = minus->getResultInterfaceType();

lib/ClangImporter/ClangImporter.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2787,6 +2787,8 @@ getClangOwningModule(ClangNode Node, const clang::ASTContext &ClangCtx) {
27872787
// Let's use the owning module of the template pattern.
27882788
originalDecl = pattern;
27892789
}
2790+
if (functionDecl->isImplicit()) {
2791+
}
27902792
}
27912793

27922794
return ExtSource->getModule(originalDecl->getOwningModuleID());
@@ -2841,6 +2843,8 @@ static bool isVisibleFromModule(const ClangModuleUnit *ModuleFilter,
28412843
ClangASTContext);
28422844
if (OwningClangModule == ModuleFilter->getClangModule())
28432845
return true;
2846+
if (!OwningClangModule && D->isImplicit())
2847+
return true;
28442848

28452849
// Friends from class templates don't have an owning module. Just return true.
28462850
if (isa<clang::FunctionDecl>(D) &&

test/Interop/Cxx/stdlib/overlay/Inputs/custom-collection.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,21 @@ struct SimpleCollectionReadOnly {
2828
const int& operator[](int index) const { return x[index]; }
2929
};
3030

31+
template <typename T>
32+
struct HasInheritedTemplatedConstRACIterator {
33+
public:
34+
typedef InheritedTemplatedConstRACIterator<int> iterator;
35+
36+
private:
37+
iterator b = iterator(1);
38+
iterator e = iterator(6);
39+
40+
public:
41+
iterator begin() const { return b; }
42+
iterator end() const { return e; }
43+
};
44+
45+
typedef HasInheritedTemplatedConstRACIterator<int>
46+
HasInheritedTemplatedConstRACIteratorInt;
47+
3148
#endif // TEST_INTEROP_CXX_STDLIB_INPUTS_CUSTOM_COLLECTION_H

0 commit comments

Comments
 (0)