Skip to content

🍒[cxx-interop] Handle inherited templated operators during auto-conformance #67373

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/swift/ClangImporter/ClangImporter.h
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,8 @@ class ClangImporter final : public ClangModuleLoader {
clang::FunctionTemplateDecl *func,
SubstitutionMap subst) override;

bool isSynthesizedAndVisibleFromAllModules(const clang::Decl *decl);

bool isCXXMethodMutating(const clang::CXXMethodDecl *method) override;

bool isUnsafeCXXMethod(const FuncDecl *func) override;
Expand Down
180 changes: 157 additions & 23 deletions lib/ClangImporter/ClangDerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
return lookupOperator(decl, decl->getASTContext().Id_EqualsOperator, isValid);
}

static ValueDecl *getMinusOperator(NominalTypeDecl *decl) {
static FuncDecl *getMinusOperator(NominalTypeDecl *decl) {
auto binaryIntegerProto =
decl->getASTContext().getProtocol(KnownProtocolKind::BinaryInteger);
auto module = decl->getModuleContext();
Expand Down Expand Up @@ -188,11 +188,12 @@ static ValueDecl *getMinusOperator(NominalTypeDecl *decl) {
return true;
};

return lookupOperator(decl, decl->getASTContext().getIdentifier("-"),
isValid);
ValueDecl *result =
lookupOperator(decl, decl->getASTContext().getIdentifier("-"), isValid);
return dyn_cast_or_null<FuncDecl>(result);
}

static ValueDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
static FuncDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
auto isValid = [&](ValueDecl *plusEqualOp) -> bool {
auto plusEqual = dyn_cast<FuncDecl>(plusEqualOp);
if (!plusEqual || !plusEqual->hasParameterList())
Expand All @@ -219,14 +220,15 @@ static ValueDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
return true;
};

return lookupOperator(decl, decl->getASTContext().getIdentifier("+="),
isValid);
ValueDecl *result =
lookupOperator(decl, decl->getASTContext().getIdentifier("+="), isValid);
return dyn_cast_or_null<FuncDecl>(result);
}

static void instantiateTemplatedOperator(
ClangImporter::Implementation &impl,
const clang::ClassTemplateSpecializationDecl *classDecl,
clang::BinaryOperatorKind operatorKind) {
static clang::FunctionDecl *
instantiateTemplatedOperator(ClangImporter::Implementation &impl,
const clang::CXXRecordDecl *classDecl,
clang::BinaryOperatorKind operatorKind) {

clang::ASTContext &clangCtx = impl.getClangASTContext();
clang::Sema &clangSema = impl.getClangSema();
Expand All @@ -252,6 +254,7 @@ static void instantiateTemplatedOperator(
if (auto clangCallee = best->Function) {
auto lookupTable = impl.findLookupTable(classDecl);
addEntryToLookupTable(*lookupTable, clangCallee, impl.getNameImporter());
return clangCallee;
}
break;
}
Expand All @@ -260,6 +263,95 @@ static void instantiateTemplatedOperator(
case clang::OR_Deleted:
break;
}

return nullptr;
}

/// Warning: This function emits an error and stops compilation if the
/// underlying operator function is unavailable in Swift for the current target
/// (see `clang::Sema::DiagnoseAvailabilityOfDecl`).
static bool synthesizeCXXOperator(ClangImporter::Implementation &impl,
const clang::CXXRecordDecl *classDecl,
clang::BinaryOperatorKind operatorKind,
clang::QualType lhsTy, clang::QualType rhsTy,
clang::QualType returnTy) {
auto &clangCtx = impl.getClangASTContext();
auto &clangSema = impl.getClangSema();

clang::OverloadedOperatorKind opKind =
clang::BinaryOperator::getOverloadedOperator(operatorKind);
const char *opSpelling = clang::getOperatorSpelling(opKind);

auto declName = clang::DeclarationName(&clangCtx.Idents.get(opSpelling));

// Determine the Clang decl context where the new operator function will be
// created. We use the translation unit as the decl context of the new
// operator, otherwise, the operator might get imported as a static member
// function of a different type (e.g. an operator declared inside of a C++
// namespace would get imported as a member function of a Swift enum), which
// would make the operator un-discoverable to Swift name lookup.
auto declContext =
const_cast<clang::CXXRecordDecl *>(classDecl)->getDeclContext();
while (!declContext->isTranslationUnit()) {
declContext = declContext->getParent();
}

auto equalEqualTy = clangCtx.getFunctionType(
returnTy, {lhsTy, rhsTy}, clang::FunctionProtoType::ExtProtoInfo());

// Create a `bool operator==(T, T)` function.
auto equalEqualDecl = clang::FunctionDecl::Create(
clangCtx, declContext, clang::SourceLocation(), clang::SourceLocation(),
declName, equalEqualTy, clangCtx.getTrivialTypeSourceInfo(returnTy),
clang::StorageClass::SC_Static);
equalEqualDecl->setImplicit();
equalEqualDecl->setImplicitlyInline();
// If this is a static member function of a class, it needs to be public.
equalEqualDecl->setAccess(clang::AccessSpecifier::AS_public);

// Create the parameters of the function. They are not referenced from source
// code, so they don't need to have a name.
auto lhsParamId = nullptr;
auto lhsTyInfo = clangCtx.getTrivialTypeSourceInfo(lhsTy);
auto lhsParamDecl = clang::ParmVarDecl::Create(
clangCtx, equalEqualDecl, clang::SourceLocation(),
clang::SourceLocation(), lhsParamId, lhsTy, lhsTyInfo,
clang::StorageClass::SC_None, /*DefArg*/ nullptr);
auto lhsParamRefExpr = new (clangCtx) clang::DeclRefExpr(
clangCtx, lhsParamDecl, false, lhsTy, clang::ExprValueKind::VK_LValue,
clang::SourceLocation());

auto rhsParamId = nullptr;
auto rhsTyInfo = clangCtx.getTrivialTypeSourceInfo(rhsTy);
auto rhsParamDecl = clang::ParmVarDecl::Create(
clangCtx, equalEqualDecl, clang::SourceLocation(),
clang::SourceLocation(), rhsParamId, rhsTy, rhsTyInfo,
clang::StorageClass::SC_None, nullptr);
auto rhsParamRefExpr = new (clangCtx) clang::DeclRefExpr(
clangCtx, rhsParamDecl, false, rhsTy, clang::ExprValueKind::VK_LValue,
clang::SourceLocation());

equalEqualDecl->setParams({lhsParamDecl, rhsParamDecl});

// Lookup the `operator==` function that will be called under the hood.
clang::UnresolvedSet<16> operators;
// Note: calling `CreateOverloadedBinOp` emits an error if the looked up
// function is unavailable for the current target.
auto underlyingCallResult = clangSema.CreateOverloadedBinOp(
clang::SourceLocation(), operatorKind, operators, lhsParamRefExpr,
rhsParamRefExpr);
if (!underlyingCallResult.isUsable())
return false;
auto underlyingCall = underlyingCallResult.get();

auto equalEqualBody = clang::ReturnStmt::Create(
clangCtx, clang::SourceLocation(), underlyingCall, nullptr);
equalEqualDecl->setBody(equalEqualBody);

impl.synthesizedAndAlwaysVisibleDecls.insert(equalEqualDecl);
auto lookupTable = impl.findLookupTable(classDecl);
addEntryToLookupTable(*lookupTable, equalEqualDecl, impl.getNameImporter());
return true;
}

bool swift::isIterator(const clang::CXXRecordDecl *clangDecl) {
Expand All @@ -274,6 +366,7 @@ void swift::conformToCxxIteratorIfNeeded(
assert(decl);
assert(clangDecl);
ASTContext &ctx = decl->getASTContext();
clang::ASTContext &clangCtx = clangDecl->getASTContext();

if (!ctx.getProtocol(KnownProtocolKind::UnsafeCxxInputIterator))
return;
Expand Down Expand Up @@ -349,15 +442,28 @@ void swift::conformToCxxIteratorIfNeeded(
if (!successorTy || successorTy->getAnyNominal() != decl)
return;

// If this is a templated class, `operator==` might be templated as well.
// Try to instantiate it.
if (auto templateSpec =
dyn_cast<clang::ClassTemplateSpecializationDecl>(clangDecl)) {
instantiateTemplatedOperator(impl, templateSpec,
clang::BinaryOperatorKind::BO_EQ);
}
// Check if present: `func ==`
auto equalEqual = getEqualEqualOperator(decl);
if (!equalEqual) {
// If this class is inherited, `operator==` might be defined for a base
// class. If this is a templated class, `operator==` might be templated as
// well. Try to instantiate it.
clang::FunctionDecl *instantiated = instantiateTemplatedOperator(
impl, clangDecl, clang::BinaryOperatorKind::BO_EQ);
if (instantiated && !impl.isUnavailableInSwift(instantiated)) {
// If `operator==` was instantiated successfully, try to find `func ==`
// again.
equalEqual = getEqualEqualOperator(decl);
if (!equalEqual) {
// If `func ==` still can't be found, it might be defined for a base
// class of the current class.
auto paramTy = clangCtx.getRecordType(clangDecl);
synthesizeCXXOperator(impl, clangDecl, clang::BinaryOperatorKind::BO_EQ,
paramTy, paramTy, clangCtx.BoolTy);
equalEqual = getEqualEqualOperator(decl);
}
}
}
if (!equalEqual)
return;

Expand All @@ -371,18 +477,46 @@ void swift::conformToCxxIteratorIfNeeded(

// Try to conform to UnsafeCxxRandomAccessIterator if possible.

if (auto templateSpec =
dyn_cast<clang::ClassTemplateSpecializationDecl>(clangDecl)) {
instantiateTemplatedOperator(impl, templateSpec,
clang::BinaryOperatorKind::BO_Sub);
// Check if present: `func -`
auto minus = getMinusOperator(decl);
if (!minus) {
clang::FunctionDecl *instantiated = instantiateTemplatedOperator(
impl, clangDecl, clang::BinaryOperatorKind::BO_Sub);
if (instantiated && !impl.isUnavailableInSwift(instantiated)) {
minus = getMinusOperator(decl);
if (!minus) {
clang::QualType returnTy = instantiated->getReturnType();
auto paramTy = clangCtx.getRecordType(clangDecl);
synthesizeCXXOperator(impl, clangDecl,
clang::BinaryOperatorKind::BO_Sub, paramTy,
paramTy, returnTy);
minus = getMinusOperator(decl);
}
}
}
auto minus = dyn_cast_or_null<FuncDecl>(getMinusOperator(decl));
if (!minus)
return;
auto distanceTy = minus->getResultInterfaceType();
// distanceTy conforms to BinaryInteger, this is ensured by getMinusOperator.

auto plusEqual = dyn_cast_or_null<FuncDecl>(getPlusEqualOperator(decl, distanceTy));
auto plusEqual = getPlusEqualOperator(decl, distanceTy);
if (!plusEqual) {
clang::FunctionDecl *instantiated = instantiateTemplatedOperator(
impl, clangDecl, clang::BinaryOperatorKind::BO_AddAssign);
if (instantiated && !impl.isUnavailableInSwift(instantiated)) {
plusEqual = getPlusEqualOperator(decl, distanceTy);
if (!plusEqual) {
clang::QualType returnTy = instantiated->getReturnType();
auto clangMinus = cast<clang::FunctionDecl>(minus->getClangDecl());
auto lhsTy = clangCtx.getRecordType(clangDecl);
auto rhsTy = clangMinus->getReturnType();
synthesizeCXXOperator(impl, clangDecl,
clang::BinaryOperatorKind::BO_AddAssign, lhsTy,
rhsTy, returnTy);
plusEqual = getPlusEqualOperator(decl, distanceTy);
}
}
}
if (!plusEqual)
return;

Expand Down
11 changes: 11 additions & 0 deletions lib/ClangImporter/ClangImporter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2796,6 +2796,12 @@ static bool isVisibleFromModule(const ClangModuleUnit *ModuleFilter,
if (OwningClangModule == ModuleFilter->getClangModule())
return true;

// If this decl was implicitly synthesized by the compiler, and is not
// supposed to be owned by any module, return true.
if (Importer->isSynthesizedAndVisibleFromAllModules(D)) {
return true;
}

// Friends from class templates don't have an owning module. Just return true.
if (isa<clang::FunctionDecl>(D) &&
cast<clang::FunctionDecl>(D)->isThisDeclarationInstantiatedFromAFriendDefinition())
Expand Down Expand Up @@ -6256,6 +6262,11 @@ FuncDecl *ClangImporter::getCXXSynthesizedOperatorFunc(FuncDecl *decl) {
return cast<FuncDecl>(synthesizedOperator);
}

bool ClangImporter::isSynthesizedAndVisibleFromAllModules(
const clang::Decl *decl) {
return Impl.synthesizedAndAlwaysVisibleDecls.contains(decl);
}

bool ClangImporter::isCXXMethodMutating(const clang::CXXMethodDecl *method) {
if (isa<clang::CXXConstructorDecl>(method) || !method->isConst())
return true;
Expand Down
2 changes: 2 additions & 0 deletions lib/ClangImporter/ImporterImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,8 @@ class LLVM_LIBRARY_VISIBILITY ClangImporter::Implementation
llvm::MapVector<std::pair<NominalTypeDecl *, Type>,
std::pair<FuncDecl *, FuncDecl *>> cxxSubscripts;

llvm::SmallPtrSet<const clang::Decl *, 1> synthesizedAndAlwaysVisibleDecls;

private:
// Keep track of the decls that were already cloned for this specific class.
llvm::DenseMap<std::pair<ValueDecl *, DeclContext *>, ValueDecl *>
Expand Down
34 changes: 34 additions & 0 deletions test/Interop/Cxx/stdlib/overlay/Inputs/custom-collection.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,38 @@ struct SimpleCollectionReadOnly {
const int& operator[](int index) const { return x[index]; }
};

template <typename T>
struct HasInheritedTemplatedConstRACIterator {
public:
typedef InheritedTemplatedConstRACIterator<int> iterator;

private:
iterator b = iterator(1);
iterator e = iterator(6);

public:
iterator begin() const { return b; }
iterator end() const { return e; }
};

typedef HasInheritedTemplatedConstRACIterator<int>
HasInheritedTemplatedConstRACIteratorInt;

template <typename T>
struct HasInheritedTemplatedConstRACIteratorOutOfLineOps {
public:
typedef InheritedTemplatedConstRACIteratorOutOfLineOps<int> iterator;

private:
iterator b = iterator(1);
iterator e = iterator(4);

public:
iterator begin() const { return b; }
iterator end() const { return e; }
};

typedef HasInheritedTemplatedConstRACIteratorOutOfLineOps<int>
HasInheritedTemplatedConstRACIteratorOutOfLineOpsInt;

#endif // TEST_INTEROP_CXX_STDLIB_INPUTS_CUSTOM_COLLECTION_H
Loading