Skip to content

Commit d8f9197

Browse files
committed
[cxx-interop] [cxx-interop] Infer SWIFT_SHARED_REFERENCE for types inheriting from a C++ foreign reference type
rdar://97914474
1 parent c42268e commit d8f9197

18 files changed

+810
-102
lines changed

include/swift/AST/DiagnosticsClangImporter.def

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,9 @@ ERROR(returns_retained_or_returns_unretained_for_non_cxx_frt_values, none,
273273
"a SWIFT_SHARED_REFERENCE type",
274274
(const clang::NamedDecl *))
275275

276-
// TODO: make this case an error in next cxx-interop versions rdar://138806722
276+
// TODO: In the next C++ interop version, convert this warning into an error and
277+
// stop importing unannotated C++ APIs that return SWIFT_SHARED_REFERENCE.
278+
// rdar://138806722
277279
WARNING(no_returns_retained_returns_unretained, none,
278280
"%0 should be annotated with either SWIFT_RETURNS_RETAINED or "
279281
"SWIFT_RETURNS_UNRETAINED as it is returning a SWIFT_SHARED_REFERENCE",
@@ -286,6 +288,15 @@ WARNING(returns_retained_returns_unretained_on_overloaded_operator, none,
286288
"SWIFT_SHARED_REFERENCE types as owned ",
287289
(const clang::NamedDecl *))
288290

291+
// TODO: In the next C++ interop version, convert this warning into an error and
292+
// stop importing C++ types that inherit from SWIFT_SHARED_REFERENCE if the
293+
// Swift compiler cannot find unique retain/release functions.
294+
// rdar://145194375
295+
WARNING(cant_infer_frt_in_cxx_inheritance, none,
296+
"unable to infer SWIFT_SHARED_REFERENCE for %0, although one of its "
297+
"transitive base types is marked as SWIFT_SHARED_REFERENCE",
298+
(const clang::NamedDecl *))
299+
289300
NOTE(unsupported_builtin_type, none, "built-in type '%0' not supported", (StringRef))
290301
NOTE(record_field_not_imported, none, "field %0 unavailable (cannot import)", (const clang::NamedDecl*))
291302
NOTE(invoked_func_not_imported, none, "function %0 unavailable (cannot import)", (const clang::NamedDecl*))

include/swift/ClangImporter/ClangImporterRequests.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,15 +342,17 @@ enum class CxxRecordSemanticsKind {
342342
struct CxxRecordSemanticsDescriptor final {
343343
const clang::RecordDecl *decl;
344344
ASTContext &ctx;
345+
ClangImporter::Implementation *importerImpl;
345346

346347
/// Whether to emit warnings for missing destructor or copy constructor
347348
/// whenever the classification of the type assumes that they exist (e.g. for
348349
/// a value type).
349350
bool shouldDiagnoseLifetimeOperations;
350351

351352
CxxRecordSemanticsDescriptor(const clang::RecordDecl *decl, ASTContext &ctx,
353+
ClangImporter::Implementation *importerImpl,
352354
bool shouldDiagnoseLifetimeOperations = true)
353-
: decl(decl), ctx(ctx),
355+
: decl(decl), ctx(ctx), importerImpl(importerImpl),
354356
shouldDiagnoseLifetimeOperations(shouldDiagnoseLifetimeOperations) {}
355357

356358
friend llvm::hash_code hash_value(const CxxRecordSemanticsDescriptor &desc) {

lib/AST/Decl.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6602,9 +6602,11 @@ bool ClassDecl::isForeignReferenceType() const {
66026602
if (!clangRecordDecl)
66036603
return false;
66046604

6605+
// `importerImpl` is set to nullptr here to avoid diagnostics during this
6606+
// CxxRecordSemantics evaluation.
66056607
CxxRecordSemanticsKind kind = evaluateOrDefault(
66066608
getASTContext().evaluator,
6607-
CxxRecordSemantics({clangRecordDecl, getASTContext()}), {});
6609+
CxxRecordSemantics({clangRecordDecl, getASTContext(), nullptr}), {});
66086610
return kind == CxxRecordSemanticsKind::Reference;
66096611
}
66106612

lib/ClangImporter/ClangClassTemplateNamePrinter.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "ClangClassTemplateNamePrinter.h"
1414
#include "ImporterImpl.h"
15+
#include "swift/ClangImporter/ClangImporter.h"
1516
#include "clang/AST/TemplateArgumentVisitor.h"
1617
#include "clang/AST/TypeVisitor.h"
1718

@@ -24,10 +25,14 @@ struct TemplateInstantiationNamePrinter
2425
NameImporter *nameImporter;
2526
ImportNameVersion version;
2627

28+
ClangImporter::Implementation *importerImpl;
29+
2730
TemplateInstantiationNamePrinter(ASTContext &swiftCtx,
2831
NameImporter *nameImporter,
29-
ImportNameVersion version)
30-
: swiftCtx(swiftCtx), nameImporter(nameImporter), version(version) {}
32+
ImportNameVersion version,
33+
ClangImporter::Implementation *importerImpl)
34+
: swiftCtx(swiftCtx), nameImporter(nameImporter), version(version),
35+
importerImpl(importerImpl) {}
3136

3237
std::string VisitType(const clang::Type *type) {
3338
// Print "_" as a fallback if we couldn't emit a more meaningful type name.
@@ -100,9 +105,7 @@ struct TemplateInstantiationNamePrinter
100105
bool isReferenceType = false;
101106
if (auto tagDecl = type->getPointeeType()->getAsTagDecl()) {
102107
if (auto *rd = dyn_cast<clang::RecordDecl>(tagDecl))
103-
isReferenceType =
104-
ClangImporter::Implementation::recordHasReferenceSemantics(
105-
rd, swiftCtx);
108+
isReferenceType = recordHasReferenceSemantics(rd, importerImpl);
106109
}
107110

108111
TagTypeDecorator decorator;
@@ -167,8 +170,9 @@ struct TemplateArgumentPrinter
167170
TemplateInstantiationNamePrinter typePrinter;
168171

169172
TemplateArgumentPrinter(ASTContext &swiftCtx, NameImporter *nameImporter,
170-
ImportNameVersion version)
171-
: typePrinter(swiftCtx, nameImporter, version) {}
173+
ImportNameVersion version,
174+
ClangImporter::Implementation *importerImpl)
175+
: typePrinter(swiftCtx, nameImporter, version, importerImpl) {}
172176

173177
void VisitTemplateArgument(const clang::TemplateArgument &arg,
174178
llvm::raw_svector_ostream &buffer) {
@@ -219,7 +223,8 @@ struct TemplateArgumentPrinter
219223
std::string swift::importer::printClassTemplateSpecializationName(
220224
const clang::ClassTemplateSpecializationDecl *decl, ASTContext &swiftCtx,
221225
NameImporter *nameImporter, ImportNameVersion version) {
222-
TemplateArgumentPrinter templateArgPrinter(swiftCtx, nameImporter, version);
226+
TemplateArgumentPrinter templateArgPrinter(swiftCtx, nameImporter, version,
227+
nameImporter->getImporterImpl());
223228

224229
llvm::SmallString<128> storage;
225230
llvm::raw_svector_ostream buffer(storage);

lib/ClangImporter/ClangImporter.cpp

Lines changed: 161 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,10 +1406,10 @@ ClangImporter::create(ASTContext &ctx,
14061406
// Install a Clang module file extension to build Swift name lookup tables.
14071407
importer->Impl.Invocation->getFrontendOpts().ModuleFileExtensions.push_back(
14081408
std::make_shared<SwiftNameLookupExtension>(
1409-
importer->Impl.BridgingHeaderLookupTable,
1410-
importer->Impl.LookupTables, importer->Impl.SwiftContext,
1409+
importer->Impl.BridgingHeaderLookupTable, importer->Impl.LookupTables,
1410+
importer->Impl.SwiftContext,
14111411
importer->Impl.getBufferImporterForDiagnostics(),
1412-
importer->Impl.platformAvailability));
1412+
importer->Impl.platformAvailability, &importer->Impl));
14131413

14141414
// Create a compiler instance.
14151415
{
@@ -1557,7 +1557,7 @@ ClangImporter::create(ASTContext &ctx,
15571557

15581558
importer->Impl.nameImporter.reset(new NameImporter(
15591559
importer->Impl.SwiftContext, importer->Impl.platformAvailability,
1560-
importer->Impl.getClangSema()));
1560+
importer->Impl.getClangSema(), &importer->Impl));
15611561

15621562
// FIXME: These decls are not being parsed correctly since (a) some of the
15631563
// callbacks are still being added, and (b) the logic to parse them has
@@ -7678,7 +7678,146 @@ bool importer::isForeignReferenceTypeWithoutImmortalAttrs(const clang::QualType
76787678
!hasImmortalAtts(pointeeType->getDecl());
76797679
}
76807680

7681+
static bool hasDiamondInheritanceRefType(const clang::CXXRecordDecl *decl) {
7682+
if (!decl->hasDefinition() || decl->isDependentType())
7683+
return false;
7684+
7685+
llvm::DenseSet<const clang::CXXRecordDecl *> seenBases;
7686+
bool hasRefDiamond = false;
7687+
7688+
decl->forallBases([&](const clang::CXXRecordDecl *Base) {
7689+
if (hasImportAsRefAttr(Base) && !seenBases.insert(Base).second &&
7690+
!decl->isVirtuallyDerivedFrom(Base))
7691+
hasRefDiamond = true;
7692+
return true;
7693+
});
7694+
7695+
return hasRefDiamond;
7696+
}
7697+
7698+
// Returns the given declaration along with all its parent declarations that are
7699+
// reference types.
7700+
static llvm::SmallVector<const clang::RecordDecl *, 4>
7701+
getRefParentDecls(const clang::RecordDecl *decl, ASTContext &ctx,
7702+
ClangImporter::Implementation *importerImpl) {
7703+
assert(decl && "decl is null inside getRefParentDecls");
7704+
7705+
llvm::SmallVector<const clang::RecordDecl *, 4> matchingDecls;
7706+
7707+
if (hasImportAsRefAttr(decl))
7708+
matchingDecls.push_back(decl);
7709+
7710+
if (const auto *cxxRecordDecl = llvm::dyn_cast<clang::CXXRecordDecl>(decl)) {
7711+
if (!cxxRecordDecl->hasDefinition())
7712+
return matchingDecls;
7713+
if (hasDiamondInheritanceRefType(cxxRecordDecl)) {
7714+
if (importerImpl) {
7715+
if (!importerImpl->DiagnosedCxxRefDecls.count(decl)) {
7716+
HeaderLoc loc(decl->getLocation());
7717+
importerImpl->diagnose(loc, diag::cant_infer_frt_in_cxx_inheritance,
7718+
decl);
7719+
importerImpl->DiagnosedCxxRefDecls.insert(decl);
7720+
}
7721+
} else {
7722+
ctx.Diags.diagnose({}, diag::cant_infer_frt_in_cxx_inheritance, decl);
7723+
assert(false && "nullpointer passeed for importerImpl when calling "
7724+
"getRefParentOrDiag");
7725+
}
7726+
return matchingDecls;
7727+
}
7728+
cxxRecordDecl->forallBases([&](const clang::CXXRecordDecl *baseDecl) {
7729+
if (hasImportAsRefAttr(baseDecl))
7730+
matchingDecls.push_back(baseDecl);
7731+
return true;
7732+
});
7733+
}
7734+
7735+
return matchingDecls;
7736+
}
7737+
7738+
static llvm::SmallVector<ValueDecl *, 1>
7739+
getValueDeclsForName(const clang::Decl *decl, ASTContext &ctx, StringRef name) {
7740+
llvm::SmallVector<ValueDecl *, 1> results;
7741+
auto *clangMod = decl->getOwningModule();
7742+
if (clangMod && clangMod->isSubModule())
7743+
clangMod = clangMod->getTopLevelModule();
7744+
if (clangMod) {
7745+
auto parentModule =
7746+
ctx.getClangModuleLoader()->getWrapperForModule(clangMod);
7747+
ctx.lookupInModule(parentModule, name, results);
7748+
} else {
7749+
// There is no Clang module for this declaration, so perform lookup from
7750+
// the main module. This will find declarations from the bridging header.
7751+
namelookup::lookupInModule(
7752+
ctx.MainModule, ctx.getIdentifier(name), results,
7753+
NLKind::UnqualifiedLookup, namelookup::ResolutionKind::Overloadable,
7754+
ctx.MainModule, SourceLoc(), NL_UnqualifiedDefault);
7755+
7756+
// Filter out any declarations that didn't come from Clang.
7757+
auto newEnd =
7758+
std::remove_if(results.begin(), results.end(),
7759+
[&](ValueDecl *decl) { return !decl->getClangDecl(); });
7760+
results.erase(newEnd, results.end());
7761+
}
7762+
return results;
7763+
}
7764+
7765+
static const clang::RecordDecl *
7766+
getRefParentOrDiag(const clang::RecordDecl *decl, ASTContext &ctx,
7767+
ClangImporter::Implementation *importerImpl) {
7768+
auto refParentDecls = getRefParentDecls(decl, ctx, importerImpl);
7769+
if (refParentDecls.empty())
7770+
return nullptr;
7771+
7772+
std::unordered_set<ValueDecl *> uniqueRetainDecls{}, uniqueReleaseDecls{};
7773+
constexpr StringRef retainPrefix = "retain:";
7774+
constexpr StringRef releasePrefix = "release:";
7775+
7776+
for (const auto *refParentDecl : refParentDecls) {
7777+
assert(refParentDecl && "refParentDecl is null inside getRefParentOrDiag");
7778+
for (const auto *attr : refParentDecl->getAttrs()) {
7779+
if (const auto swiftAttr = llvm::dyn_cast<clang::SwiftAttrAttr>(attr)) {
7780+
const auto &attribute = swiftAttr->getAttribute();
7781+
llvm::SmallVector<ValueDecl *, 1> valueDecls;
7782+
if (attribute.starts_with(retainPrefix)) {
7783+
auto name = attribute.drop_front(retainPrefix.size()).str();
7784+
valueDecls = getValueDeclsForName(decl, ctx, name);
7785+
uniqueRetainDecls.insert(valueDecls.begin(), valueDecls.end());
7786+
} else if (attribute.starts_with(releasePrefix)) {
7787+
auto name = attribute.drop_front(releasePrefix.size()).str();
7788+
valueDecls = getValueDeclsForName(decl, ctx, name);
7789+
uniqueReleaseDecls.insert(valueDecls.begin(), valueDecls.end());
7790+
}
7791+
}
7792+
}
7793+
}
7794+
7795+
// Ensure that exactly one unique retain function and one unique release
7796+
// function are found.
7797+
if (uniqueRetainDecls.size() != 1 || uniqueReleaseDecls.size() != 1) {
7798+
if (importerImpl) {
7799+
if (!importerImpl->DiagnosedCxxRefDecls.count(decl)) {
7800+
HeaderLoc loc(decl->getLocation());
7801+
importerImpl->diagnose(loc, diag::cant_infer_frt_in_cxx_inheritance,
7802+
decl);
7803+
importerImpl->DiagnosedCxxRefDecls.insert(decl);
7804+
}
7805+
} else {
7806+
ctx.Diags.diagnose({}, diag::cant_infer_frt_in_cxx_inheritance, decl);
7807+
assert(false && "nullpointer passed for importerImpl when calling "
7808+
"getRefParentOrDiag");
7809+
}
7810+
return nullptr;
7811+
}
7812+
7813+
return refParentDecls.front();
7814+
}
7815+
76817816
// Is this a pointer to a foreign reference type.
7817+
// TODO: We need to review functions like this to ensure that
7818+
// CxxRecordSemantics::evaluate is consistently invoked wherever we need to
7819+
// determine whether a C++ type qualifies as a foreign reference type
7820+
// rdar://145184659
76827821
static bool isForeignReferenceType(const clang::QualType type) {
76837822
if (!type->isPointerType())
76847823
return false;
@@ -7927,10 +8066,10 @@ CxxRecordSemanticsKind
79278066
CxxRecordSemantics::evaluate(Evaluator &evaluator,
79288067
CxxRecordSemanticsDescriptor desc) const {
79298068
const auto *decl = desc.decl;
7930-
7931-
if (hasImportAsRefAttr(decl)) {
8069+
ClangImporter::Implementation *importerImpl = desc.importerImpl;
8070+
if (hasImportAsRefAttr(decl) ||
8071+
getRefParentOrDiag(decl, desc.ctx, importerImpl))
79328072
return CxxRecordSemanticsKind::Reference;
7933-
}
79348073

79358074
auto cxxDecl = dyn_cast<clang::CXXRecordDecl>(decl);
79368075
if (!cxxDecl) {
@@ -7943,15 +8082,16 @@ CxxRecordSemantics::evaluate(Evaluator &evaluator,
79438082
if (!hasDestroyTypeOperations(cxxDecl) ||
79448083
(!hasCopyTypeOperations(cxxDecl) && !hasMoveTypeOperations(cxxDecl))) {
79458084
if (desc.shouldDiagnoseLifetimeOperations) {
8085+
HeaderLoc loc(decl->getLocation());
79468086
if (hasUnsafeAPIAttr(cxxDecl))
7947-
desc.ctx.Diags.diagnose({}, diag::api_pattern_attr_ignored,
7948-
"import_unsafe", decl->getNameAsString());
8087+
importerImpl->diagnose(loc, diag::api_pattern_attr_ignored,
8088+
"import_unsafe", decl->getNameAsString());
79498089
if (hasOwnedValueAttr(cxxDecl))
7950-
desc.ctx.Diags.diagnose({}, diag::api_pattern_attr_ignored,
7951-
"import_owned", decl->getNameAsString());
8090+
importerImpl->diagnose(loc, diag::api_pattern_attr_ignored,
8091+
"import_owned", decl->getNameAsString());
79528092
if (hasIteratorAPIAttr(cxxDecl))
7953-
desc.ctx.Diags.diagnose({}, diag::api_pattern_attr_ignored,
7954-
"import_iterator", decl->getNameAsString());
8093+
importerImpl->diagnose(loc, diag::api_pattern_attr_ignored,
8094+
"import_iterator", decl->getNameAsString());
79558095
}
79568096

79578097
return CxxRecordSemanticsKind::MissingLifetimeOperation;
@@ -8158,6 +8298,12 @@ CustomRefCountingOperationResult CustomRefCountingOperation::evaluate(
81588298
: "release:";
81598299

81608300
auto decl = cast<clang::RecordDecl>(swiftDecl->getClangDecl());
8301+
8302+
if (!hasImportAsRefAttr(decl)) {
8303+
if (auto parentRefDecl = getRefParentOrDiag(decl, ctx, nullptr))
8304+
decl = parentRefDecl;
8305+
}
8306+
81618307
if (!decl->hasAttrs())
81628308
return {CustomRefCountingOperationResult::noAttribute, nullptr, ""};
81638309

@@ -8184,27 +8330,8 @@ CustomRefCountingOperationResult CustomRefCountingOperation::evaluate(
81848330
if (name == "immortal")
81858331
return {CustomRefCountingOperationResult::immortal, nullptr, name};
81868332

8187-
llvm::SmallVector<ValueDecl *, 1> results;
8188-
auto *clangMod = swiftDecl->getClangDecl()->getOwningModule();
8189-
if (clangMod && clangMod->isSubModule())
8190-
clangMod = clangMod->getTopLevelModule();
8191-
if (clangMod) {
8192-
auto parentModule = ctx.getClangModuleLoader()->getWrapperForModule(clangMod);
8193-
ctx.lookupInModule(parentModule, name, results);
8194-
} else {
8195-
// There is no Clang module for this declaration, so perform lookup from
8196-
// the main module. This will find declarations from the bridging header.
8197-
namelookup::lookupInModule(
8198-
ctx.MainModule, ctx.getIdentifier(name), results,
8199-
NLKind::UnqualifiedLookup, namelookup::ResolutionKind::Overloadable,
8200-
ctx.MainModule, SourceLoc(), NL_UnqualifiedDefault);
8201-
8202-
// Filter out any declarations that didn't come from Clang.
8203-
auto newEnd = std::remove_if(results.begin(), results.end(), [&](ValueDecl *decl) {
8204-
return !decl->getClangDecl();
8205-
});
8206-
results.erase(newEnd, results.end());
8207-
}
8333+
llvm::SmallVector<ValueDecl *, 1> results =
8334+
getValueDeclsForName(swiftDecl->getClangDecl(), ctx, name);
82088335
if (results.size() == 1)
82098336
return {CustomRefCountingOperationResult::foundOperation, results.front(),
82108337
name};

0 commit comments

Comments
 (0)