Skip to content

Commit 9782eae

Browse files
authored
Merge pull request #77522 from fahadnayyar/cxx-frt-inheritance-diagnostics
[cxx-interop] Infer SWIFT_SHARED_REFERENCE for types inheriting from a C++ foreign reference type
2 parents b1fb62d + d8f9197 commit 9782eae

18 files changed

+810
-102
lines changed

include/swift/AST/DiagnosticsClangImporter.def

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,9 @@ ERROR(returns_retained_or_returns_unretained_for_non_cxx_frt_values, none,
295295
"a SWIFT_SHARED_REFERENCE type",
296296
(const clang::NamedDecl *))
297297

298-
// TODO: make this case an error in next cxx-interop versions rdar://138806722
298+
// TODO: In the next C++ interop version, convert this warning into an error and
299+
// stop importing unannotated C++ APIs that return SWIFT_SHARED_REFERENCE.
300+
// rdar://138806722
299301
WARNING(no_returns_retained_returns_unretained, none,
300302
"%0 should be annotated with either SWIFT_RETURNS_RETAINED or "
301303
"SWIFT_RETURNS_UNRETAINED as it is returning a SWIFT_SHARED_REFERENCE",
@@ -308,6 +310,15 @@ WARNING(returns_retained_returns_unretained_on_overloaded_operator, none,
308310
"SWIFT_SHARED_REFERENCE types as owned ",
309311
(const clang::NamedDecl *))
310312

313+
// TODO: In the next C++ interop version, convert this warning into an error and
314+
// stop importing C++ types that inherit from SWIFT_SHARED_REFERENCE if the
315+
// Swift compiler cannot find unique retain/release functions.
316+
// rdar://145194375
317+
WARNING(cant_infer_frt_in_cxx_inheritance, none,
318+
"unable to infer SWIFT_SHARED_REFERENCE for %0, although one of its "
319+
"transitive base types is marked as SWIFT_SHARED_REFERENCE",
320+
(const clang::NamedDecl *))
321+
311322
NOTE(unsupported_builtin_type, none, "built-in type '%0' not supported", (StringRef))
312323
NOTE(record_field_not_imported, none, "field %0 unavailable (cannot import)", (const clang::NamedDecl*))
313324
NOTE(invoked_func_not_imported, none, "function %0 unavailable (cannot import)", (const clang::NamedDecl*))

include/swift/ClangImporter/ClangImporterRequests.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,15 +343,17 @@ enum class CxxRecordSemanticsKind {
343343
struct CxxRecordSemanticsDescriptor final {
344344
const clang::RecordDecl *decl;
345345
ASTContext &ctx;
346+
ClangImporter::Implementation *importerImpl;
346347

347348
/// Whether to emit warnings for missing destructor or copy constructor
348349
/// whenever the classification of the type assumes that they exist (e.g. for
349350
/// a value type).
350351
bool shouldDiagnoseLifetimeOperations;
351352

352353
CxxRecordSemanticsDescriptor(const clang::RecordDecl *decl, ASTContext &ctx,
354+
ClangImporter::Implementation *importerImpl,
353355
bool shouldDiagnoseLifetimeOperations = true)
354-
: decl(decl), ctx(ctx),
356+
: decl(decl), ctx(ctx), importerImpl(importerImpl),
355357
shouldDiagnoseLifetimeOperations(shouldDiagnoseLifetimeOperations) {}
356358

357359
friend llvm::hash_code hash_value(const CxxRecordSemanticsDescriptor &desc) {

lib/AST/Decl.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6608,9 +6608,11 @@ bool ClassDecl::isForeignReferenceType() const {
66086608
if (!clangRecordDecl)
66096609
return false;
66106610

6611+
// `importerImpl` is set to nullptr here to avoid diagnostics during this
6612+
// CxxRecordSemantics evaluation.
66116613
CxxRecordSemanticsKind kind = evaluateOrDefault(
66126614
getASTContext().evaluator,
6613-
CxxRecordSemantics({clangRecordDecl, getASTContext()}), {});
6615+
CxxRecordSemantics({clangRecordDecl, getASTContext(), nullptr}), {});
66146616
return kind == CxxRecordSemanticsKind::Reference;
66156617
}
66166618

lib/ClangImporter/ClangClassTemplateNamePrinter.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "ClangClassTemplateNamePrinter.h"
1414
#include "ImporterImpl.h"
15+
#include "swift/ClangImporter/ClangImporter.h"
1516
#include "clang/AST/TemplateArgumentVisitor.h"
1617
#include "clang/AST/TypeVisitor.h"
1718

@@ -24,10 +25,14 @@ struct TemplateInstantiationNamePrinter
2425
NameImporter *nameImporter;
2526
ImportNameVersion version;
2627

28+
ClangImporter::Implementation *importerImpl;
29+
2730
TemplateInstantiationNamePrinter(ASTContext &swiftCtx,
2831
NameImporter *nameImporter,
29-
ImportNameVersion version)
30-
: swiftCtx(swiftCtx), nameImporter(nameImporter), version(version) {}
32+
ImportNameVersion version,
33+
ClangImporter::Implementation *importerImpl)
34+
: swiftCtx(swiftCtx), nameImporter(nameImporter), version(version),
35+
importerImpl(importerImpl) {}
3136

3237
std::string VisitType(const clang::Type *type) {
3338
// Print "_" as a fallback if we couldn't emit a more meaningful type name.
@@ -100,9 +105,7 @@ struct TemplateInstantiationNamePrinter
100105
bool isReferenceType = false;
101106
if (auto tagDecl = type->getPointeeType()->getAsTagDecl()) {
102107
if (auto *rd = dyn_cast<clang::RecordDecl>(tagDecl))
103-
isReferenceType =
104-
ClangImporter::Implementation::recordHasReferenceSemantics(
105-
rd, swiftCtx);
108+
isReferenceType = recordHasReferenceSemantics(rd, importerImpl);
106109
}
107110

108111
TagTypeDecorator decorator;
@@ -167,8 +170,9 @@ struct TemplateArgumentPrinter
167170
TemplateInstantiationNamePrinter typePrinter;
168171

169172
TemplateArgumentPrinter(ASTContext &swiftCtx, NameImporter *nameImporter,
170-
ImportNameVersion version)
171-
: typePrinter(swiftCtx, nameImporter, version) {}
173+
ImportNameVersion version,
174+
ClangImporter::Implementation *importerImpl)
175+
: typePrinter(swiftCtx, nameImporter, version, importerImpl) {}
172176

173177
void VisitTemplateArgument(const clang::TemplateArgument &arg,
174178
llvm::raw_svector_ostream &buffer) {
@@ -219,7 +223,8 @@ struct TemplateArgumentPrinter
219223
std::string swift::importer::printClassTemplateSpecializationName(
220224
const clang::ClassTemplateSpecializationDecl *decl, ASTContext &swiftCtx,
221225
NameImporter *nameImporter, ImportNameVersion version) {
222-
TemplateArgumentPrinter templateArgPrinter(swiftCtx, nameImporter, version);
226+
TemplateArgumentPrinter templateArgPrinter(swiftCtx, nameImporter, version,
227+
nameImporter->getImporterImpl());
223228

224229
llvm::SmallString<128> storage;
225230
llvm::raw_svector_ostream buffer(storage);

lib/ClangImporter/ClangImporter.cpp

Lines changed: 161 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,10 +1406,10 @@ ClangImporter::create(ASTContext &ctx,
14061406
// Install a Clang module file extension to build Swift name lookup tables.
14071407
importer->Impl.Invocation->getFrontendOpts().ModuleFileExtensions.push_back(
14081408
std::make_shared<SwiftNameLookupExtension>(
1409-
importer->Impl.BridgingHeaderLookupTable,
1410-
importer->Impl.LookupTables, importer->Impl.SwiftContext,
1409+
importer->Impl.BridgingHeaderLookupTable, importer->Impl.LookupTables,
1410+
importer->Impl.SwiftContext,
14111411
importer->Impl.getBufferImporterForDiagnostics(),
1412-
importer->Impl.platformAvailability));
1412+
importer->Impl.platformAvailability, &importer->Impl));
14131413

14141414
// Create a compiler instance.
14151415
{
@@ -1557,7 +1557,7 @@ ClangImporter::create(ASTContext &ctx,
15571557

15581558
importer->Impl.nameImporter.reset(new NameImporter(
15591559
importer->Impl.SwiftContext, importer->Impl.platformAvailability,
1560-
importer->Impl.getClangSema()));
1560+
importer->Impl.getClangSema(), &importer->Impl));
15611561

15621562
// FIXME: These decls are not being parsed correctly since (a) some of the
15631563
// callbacks are still being added, and (b) the logic to parse them has
@@ -7680,7 +7680,146 @@ bool importer::isForeignReferenceTypeWithoutImmortalAttrs(const clang::QualType
76807680
!hasImmortalAtts(pointeeType->getDecl());
76817681
}
76827682

7683+
static bool hasDiamondInheritanceRefType(const clang::CXXRecordDecl *decl) {
7684+
if (!decl->hasDefinition() || decl->isDependentType())
7685+
return false;
7686+
7687+
llvm::DenseSet<const clang::CXXRecordDecl *> seenBases;
7688+
bool hasRefDiamond = false;
7689+
7690+
decl->forallBases([&](const clang::CXXRecordDecl *Base) {
7691+
if (hasImportAsRefAttr(Base) && !seenBases.insert(Base).second &&
7692+
!decl->isVirtuallyDerivedFrom(Base))
7693+
hasRefDiamond = true;
7694+
return true;
7695+
});
7696+
7697+
return hasRefDiamond;
7698+
}
7699+
7700+
// Returns the given declaration along with all its parent declarations that are
7701+
// reference types.
7702+
static llvm::SmallVector<const clang::RecordDecl *, 4>
7703+
getRefParentDecls(const clang::RecordDecl *decl, ASTContext &ctx,
7704+
ClangImporter::Implementation *importerImpl) {
7705+
assert(decl && "decl is null inside getRefParentDecls");
7706+
7707+
llvm::SmallVector<const clang::RecordDecl *, 4> matchingDecls;
7708+
7709+
if (hasImportAsRefAttr(decl))
7710+
matchingDecls.push_back(decl);
7711+
7712+
if (const auto *cxxRecordDecl = llvm::dyn_cast<clang::CXXRecordDecl>(decl)) {
7713+
if (!cxxRecordDecl->hasDefinition())
7714+
return matchingDecls;
7715+
if (hasDiamondInheritanceRefType(cxxRecordDecl)) {
7716+
if (importerImpl) {
7717+
if (!importerImpl->DiagnosedCxxRefDecls.count(decl)) {
7718+
HeaderLoc loc(decl->getLocation());
7719+
importerImpl->diagnose(loc, diag::cant_infer_frt_in_cxx_inheritance,
7720+
decl);
7721+
importerImpl->DiagnosedCxxRefDecls.insert(decl);
7722+
}
7723+
} else {
7724+
ctx.Diags.diagnose({}, diag::cant_infer_frt_in_cxx_inheritance, decl);
7725+
assert(false && "nullpointer passeed for importerImpl when calling "
7726+
"getRefParentOrDiag");
7727+
}
7728+
return matchingDecls;
7729+
}
7730+
cxxRecordDecl->forallBases([&](const clang::CXXRecordDecl *baseDecl) {
7731+
if (hasImportAsRefAttr(baseDecl))
7732+
matchingDecls.push_back(baseDecl);
7733+
return true;
7734+
});
7735+
}
7736+
7737+
return matchingDecls;
7738+
}
7739+
7740+
static llvm::SmallVector<ValueDecl *, 1>
7741+
getValueDeclsForName(const clang::Decl *decl, ASTContext &ctx, StringRef name) {
7742+
llvm::SmallVector<ValueDecl *, 1> results;
7743+
auto *clangMod = decl->getOwningModule();
7744+
if (clangMod && clangMod->isSubModule())
7745+
clangMod = clangMod->getTopLevelModule();
7746+
if (clangMod) {
7747+
auto parentModule =
7748+
ctx.getClangModuleLoader()->getWrapperForModule(clangMod);
7749+
ctx.lookupInModule(parentModule, name, results);
7750+
} else {
7751+
// There is no Clang module for this declaration, so perform lookup from
7752+
// the main module. This will find declarations from the bridging header.
7753+
namelookup::lookupInModule(
7754+
ctx.MainModule, ctx.getIdentifier(name), results,
7755+
NLKind::UnqualifiedLookup, namelookup::ResolutionKind::Overloadable,
7756+
ctx.MainModule, SourceLoc(), NL_UnqualifiedDefault);
7757+
7758+
// Filter out any declarations that didn't come from Clang.
7759+
auto newEnd =
7760+
std::remove_if(results.begin(), results.end(),
7761+
[&](ValueDecl *decl) { return !decl->getClangDecl(); });
7762+
results.erase(newEnd, results.end());
7763+
}
7764+
return results;
7765+
}
7766+
7767+
static const clang::RecordDecl *
7768+
getRefParentOrDiag(const clang::RecordDecl *decl, ASTContext &ctx,
7769+
ClangImporter::Implementation *importerImpl) {
7770+
auto refParentDecls = getRefParentDecls(decl, ctx, importerImpl);
7771+
if (refParentDecls.empty())
7772+
return nullptr;
7773+
7774+
std::unordered_set<ValueDecl *> uniqueRetainDecls{}, uniqueReleaseDecls{};
7775+
constexpr StringRef retainPrefix = "retain:";
7776+
constexpr StringRef releasePrefix = "release:";
7777+
7778+
for (const auto *refParentDecl : refParentDecls) {
7779+
assert(refParentDecl && "refParentDecl is null inside getRefParentOrDiag");
7780+
for (const auto *attr : refParentDecl->getAttrs()) {
7781+
if (const auto swiftAttr = llvm::dyn_cast<clang::SwiftAttrAttr>(attr)) {
7782+
const auto &attribute = swiftAttr->getAttribute();
7783+
llvm::SmallVector<ValueDecl *, 1> valueDecls;
7784+
if (attribute.starts_with(retainPrefix)) {
7785+
auto name = attribute.drop_front(retainPrefix.size()).str();
7786+
valueDecls = getValueDeclsForName(decl, ctx, name);
7787+
uniqueRetainDecls.insert(valueDecls.begin(), valueDecls.end());
7788+
} else if (attribute.starts_with(releasePrefix)) {
7789+
auto name = attribute.drop_front(releasePrefix.size()).str();
7790+
valueDecls = getValueDeclsForName(decl, ctx, name);
7791+
uniqueReleaseDecls.insert(valueDecls.begin(), valueDecls.end());
7792+
}
7793+
}
7794+
}
7795+
}
7796+
7797+
// Ensure that exactly one unique retain function and one unique release
7798+
// function are found.
7799+
if (uniqueRetainDecls.size() != 1 || uniqueReleaseDecls.size() != 1) {
7800+
if (importerImpl) {
7801+
if (!importerImpl->DiagnosedCxxRefDecls.count(decl)) {
7802+
HeaderLoc loc(decl->getLocation());
7803+
importerImpl->diagnose(loc, diag::cant_infer_frt_in_cxx_inheritance,
7804+
decl);
7805+
importerImpl->DiagnosedCxxRefDecls.insert(decl);
7806+
}
7807+
} else {
7808+
ctx.Diags.diagnose({}, diag::cant_infer_frt_in_cxx_inheritance, decl);
7809+
assert(false && "nullpointer passed for importerImpl when calling "
7810+
"getRefParentOrDiag");
7811+
}
7812+
return nullptr;
7813+
}
7814+
7815+
return refParentDecls.front();
7816+
}
7817+
76837818
// Is this a pointer to a foreign reference type.
7819+
// TODO: We need to review functions like this to ensure that
7820+
// CxxRecordSemantics::evaluate is consistently invoked wherever we need to
7821+
// determine whether a C++ type qualifies as a foreign reference type
7822+
// rdar://145184659
76847823
static bool isForeignReferenceType(const clang::QualType type) {
76857824
if (!type->isPointerType())
76867825
return false;
@@ -7929,10 +8068,10 @@ CxxRecordSemanticsKind
79298068
CxxRecordSemantics::evaluate(Evaluator &evaluator,
79308069
CxxRecordSemanticsDescriptor desc) const {
79318070
const auto *decl = desc.decl;
7932-
7933-
if (hasImportAsRefAttr(decl)) {
8071+
ClangImporter::Implementation *importerImpl = desc.importerImpl;
8072+
if (hasImportAsRefAttr(decl) ||
8073+
getRefParentOrDiag(decl, desc.ctx, importerImpl))
79348074
return CxxRecordSemanticsKind::Reference;
7935-
}
79368075

79378076
auto cxxDecl = dyn_cast<clang::CXXRecordDecl>(decl);
79388077
if (!cxxDecl) {
@@ -7945,15 +8084,16 @@ CxxRecordSemantics::evaluate(Evaluator &evaluator,
79458084
if (!hasDestroyTypeOperations(cxxDecl) ||
79468085
(!hasCopyTypeOperations(cxxDecl) && !hasMoveTypeOperations(cxxDecl))) {
79478086
if (desc.shouldDiagnoseLifetimeOperations) {
8087+
HeaderLoc loc(decl->getLocation());
79488088
if (hasUnsafeAPIAttr(cxxDecl))
7949-
desc.ctx.Diags.diagnose({}, diag::api_pattern_attr_ignored,
7950-
"import_unsafe", decl->getNameAsString());
8089+
importerImpl->diagnose(loc, diag::api_pattern_attr_ignored,
8090+
"import_unsafe", decl->getNameAsString());
79518091
if (hasOwnedValueAttr(cxxDecl))
7952-
desc.ctx.Diags.diagnose({}, diag::api_pattern_attr_ignored,
7953-
"import_owned", decl->getNameAsString());
8092+
importerImpl->diagnose(loc, diag::api_pattern_attr_ignored,
8093+
"import_owned", decl->getNameAsString());
79548094
if (hasIteratorAPIAttr(cxxDecl))
7955-
desc.ctx.Diags.diagnose({}, diag::api_pattern_attr_ignored,
7956-
"import_iterator", decl->getNameAsString());
8095+
importerImpl->diagnose(loc, diag::api_pattern_attr_ignored,
8096+
"import_iterator", decl->getNameAsString());
79578097
}
79588098

79598099
return CxxRecordSemanticsKind::MissingLifetimeOperation;
@@ -8160,6 +8300,12 @@ CustomRefCountingOperationResult CustomRefCountingOperation::evaluate(
81608300
: "release:";
81618301

81628302
auto decl = cast<clang::RecordDecl>(swiftDecl->getClangDecl());
8303+
8304+
if (!hasImportAsRefAttr(decl)) {
8305+
if (auto parentRefDecl = getRefParentOrDiag(decl, ctx, nullptr))
8306+
decl = parentRefDecl;
8307+
}
8308+
81638309
if (!decl->hasAttrs())
81648310
return {CustomRefCountingOperationResult::noAttribute, nullptr, ""};
81658311

@@ -8186,27 +8332,8 @@ CustomRefCountingOperationResult CustomRefCountingOperation::evaluate(
81868332
if (name == "immortal")
81878333
return {CustomRefCountingOperationResult::immortal, nullptr, name};
81888334

8189-
llvm::SmallVector<ValueDecl *, 1> results;
8190-
auto *clangMod = swiftDecl->getClangDecl()->getOwningModule();
8191-
if (clangMod && clangMod->isSubModule())
8192-
clangMod = clangMod->getTopLevelModule();
8193-
if (clangMod) {
8194-
auto parentModule = ctx.getClangModuleLoader()->getWrapperForModule(clangMod);
8195-
ctx.lookupInModule(parentModule, name, results);
8196-
} else {
8197-
// There is no Clang module for this declaration, so perform lookup from
8198-
// the main module. This will find declarations from the bridging header.
8199-
namelookup::lookupInModule(
8200-
ctx.MainModule, ctx.getIdentifier(name), results,
8201-
NLKind::UnqualifiedLookup, namelookup::ResolutionKind::Overloadable,
8202-
ctx.MainModule, SourceLoc(), NL_UnqualifiedDefault);
8203-
8204-
// Filter out any declarations that didn't come from Clang.
8205-
auto newEnd = std::remove_if(results.begin(), results.end(), [&](ValueDecl *decl) {
8206-
return !decl->getClangDecl();
8207-
});
8208-
results.erase(newEnd, results.end());
8209-
}
8335+
llvm::SmallVector<ValueDecl *, 1> results =
8336+
getValueDeclsForName(swiftDecl->getClangDecl(), ctx, name);
82108337
if (results.size() == 1)
82118338
return {CustomRefCountingOperationResult::foundOperation, results.front(),
82128339
name};

0 commit comments

Comments
 (0)