Skip to content

Commit f7ce9aa

Browse files
committed
[cxx-interop] Synthesized derived-to-base field getter should copy out a retainable FRT value
This matches the semantics of accessing the same field from the base class
1 parent 4150450 commit f7ce9aa

File tree

7 files changed

+280
-31
lines changed

7 files changed

+280
-31
lines changed

lib/ClangImporter/ClangImporter.cpp

Lines changed: 97 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4846,6 +4846,11 @@ static clang::CXXMethodDecl *synthesizeCxxBaseMethod(
48464846
newMethod->setImplicit();
48474847
newMethod->setImplicitlyInline();
48484848
newMethod->setAccess(clang::AccessSpecifier::AS_public);
4849+
if (method->hasAttr<clang::CFReturnsRetainedAttr>()) {
4850+
// Return an FRT field at +1 if the base method also follows this
4851+
// convention.
4852+
newMethod->addAttr(clang::CFReturnsRetainedAttr::CreateImplicit(clangCtx));
4853+
}
48494854

48504855
llvm::SmallVector<clang::ParmVarDecl *, 4> params;
48514856
for (size_t i = 0; i < method->getNumParams(); ++i) {
@@ -5047,7 +5052,8 @@ synthesizeBaseClassMethodBody(AbstractFunctionDecl *afd, void *context) {
50475052
// to the base class while the field is accessed.
50485053
static clang::CXXMethodDecl *synthesizeCxxBaseGetterAccessorMethod(
50495054
ClangImporter &impl, const clang::CXXRecordDecl *derivedClass,
5050-
const clang::CXXRecordDecl *baseClass, const clang::FieldDecl *field) {
5055+
const clang::CXXRecordDecl *baseClass, const clang::FieldDecl *field,
5056+
ValueDecl *retainOperationFn) {
50515057
auto &clangCtx = impl.getClangASTContext();
50525058
auto &clangSema = impl.getClangSema();
50535059

@@ -5078,51 +5084,95 @@ static clang::CXXMethodDecl *synthesizeCxxBaseGetterAccessorMethod(
50785084
newMethod->setImplicit();
50795085
newMethod->setImplicitlyInline();
50805086
newMethod->setAccess(clang::AccessSpecifier::AS_public);
5087+
if (retainOperationFn) {
5088+
// Return an FRT field at +1.
5089+
newMethod->addAttr(clang::CFReturnsRetainedAttr::CreateImplicit(clangCtx));
5090+
}
50815091

50825092
// Create a new Clang diagnostic pool to capture any diagnostics
50835093
// emitted during the construction of the method.
50845094
clang::sema::DelayedDiagnosticPool diagPool{
50855095
clangSema.DelayedDiagnostics.getCurrentPool()};
50865096
auto diagState = clangSema.DelayedDiagnostics.push(diagPool);
50875097

5098+
// Returns the expression that accesses the base field from derived type.
5099+
auto createFieldAccess = [&]() -> clang::Expr * {
5100+
auto *thisExpr = new (clangCtx)
5101+
clang::CXXThisExpr(clang::SourceLocation(), newMethod->getThisType(),
5102+
/*IsImplicit=*/false);
5103+
clang::QualType baseClassPtr = clangCtx.getRecordType(baseClass);
5104+
baseClassPtr.addConst();
5105+
baseClassPtr = clangCtx.getPointerType(baseClassPtr);
5106+
5107+
clang::CastKind Kind;
5108+
clang::CXXCastPath Path;
5109+
clangSema.CheckPointerConversion(thisExpr, baseClassPtr, Kind, Path,
5110+
/*IgnoreBaseAccess=*/false,
5111+
/*Diagnose=*/true);
5112+
auto conv = clangSema.ImpCastExprToType(thisExpr, baseClassPtr, Kind,
5113+
clang::VK_PRValue, &Path);
5114+
if (!conv.isUsable())
5115+
return nullptr;
5116+
auto memberExpr = clangSema.BuildMemberExpr(
5117+
conv.get(), /*isArrow=*/true, clang::SourceLocation(),
5118+
clang::NestedNameSpecifierLoc(), clang::SourceLocation(),
5119+
const_cast<clang::FieldDecl *>(field),
5120+
clang::DeclAccessPair::make(const_cast<clang::FieldDecl *>(field),
5121+
clang::AS_public),
5122+
/*HadMultipleCandidates=*/false,
5123+
clang::DeclarationNameInfo(field->getDeclName(),
5124+
clang::SourceLocation()),
5125+
returnType, clang::VK_LValue, clang::OK_Ordinary);
5126+
auto returnCast = clangSema.ImpCastExprToType(
5127+
memberExpr, returnType, clang::CK_LValueToRValue, clang::VK_PRValue);
5128+
if (!returnCast.isUsable())
5129+
return nullptr;
5130+
return returnCast.get();
5131+
};
5132+
5133+
llvm::SmallVector<clang::Stmt *, 2> body;
5134+
if (retainOperationFn) {
5135+
// Check if the returned value needs to be retained. This might occur if the
5136+
// field getter is returning a shared reference type using, as it needs to
5137+
// perform the retain to match the expected @owned convention.
5138+
auto *retainClangFn =
5139+
dyn_cast<clang::FunctionDecl>(retainOperationFn->getClangDecl());
5140+
if (!retainClangFn) {
5141+
return nullptr;
5142+
}
5143+
auto *fnRef = new (clangCtx) clang::DeclRefExpr(
5144+
clangCtx, const_cast<clang::FunctionDecl *>(retainClangFn), false,
5145+
retainClangFn->getType(), clang::ExprValueKind::VK_LValue,
5146+
clang::SourceLocation());
5147+
auto fieldExpr = createFieldAccess();
5148+
if (!fieldExpr)
5149+
return nullptr;
5150+
auto retainCall = clangSema.BuildResolvedCallExpr(
5151+
fnRef, const_cast<clang::FunctionDecl *>(retainClangFn),
5152+
clang::SourceLocation(), {fieldExpr}, clang::SourceLocation());
5153+
if (!retainCall.isUsable())
5154+
return nullptr;
5155+
body.push_back(retainCall.get());
5156+
}
5157+
50885158
// Construct the method's body.
5089-
auto *thisExpr = new (clangCtx) clang::CXXThisExpr(
5090-
clang::SourceLocation(), newMethod->getThisType(), /*IsImplicit=*/false);
5091-
clang::QualType baseClassPtr = clangCtx.getRecordType(baseClass);
5092-
baseClassPtr.addConst();
5093-
baseClassPtr = clangCtx.getPointerType(baseClassPtr);
5094-
5095-
clang::CastKind Kind;
5096-
clang::CXXCastPath Path;
5097-
clangSema.CheckPointerConversion(thisExpr, baseClassPtr, Kind, Path,
5098-
/*IgnoreBaseAccess=*/false,
5099-
/*Diagnose=*/true);
5100-
auto conv = clangSema.ImpCastExprToType(thisExpr, baseClassPtr, Kind,
5101-
clang::VK_PRValue, &Path);
5102-
if (!conv.isUsable())
5103-
return nullptr;
5104-
auto memberExpr = clangSema.BuildMemberExpr(
5105-
conv.get(), /*isArrow=*/true, clang::SourceLocation(),
5106-
clang::NestedNameSpecifierLoc(), clang::SourceLocation(),
5107-
const_cast<clang::FieldDecl *>(field),
5108-
clang::DeclAccessPair::make(const_cast<clang::FieldDecl *>(field),
5109-
clang::AS_public),
5110-
/*HadMultipleCandidates=*/false,
5111-
clang::DeclarationNameInfo(field->getDeclName(), clang::SourceLocation()),
5112-
returnType, clang::VK_LValue, clang::OK_Ordinary);
5113-
auto returnCast = clangSema.ImpCastExprToType(
5114-
memberExpr, returnType, clang::CK_LValueToRValue, clang::VK_PRValue);
5115-
if (!returnCast.isUsable())
5159+
auto fieldExpr = createFieldAccess();
5160+
if (!fieldExpr)
51165161
return nullptr;
51175162
auto returnStmt = clang::ReturnStmt::Create(clangCtx, clang::SourceLocation(),
5118-
returnCast.get(), nullptr);
5163+
fieldExpr, nullptr);
5164+
body.push_back(returnStmt);
51195165

51205166
// Check if there were any Clang errors during the construction
51215167
// of the method body.
51225168
clangSema.DelayedDiagnostics.popWithoutEmitting(diagState);
51235169
if (!diagPool.empty())
51245170
return nullptr;
5125-
newMethod->setBody(returnStmt);
5171+
newMethod->setBody(body.size() > 1
5172+
? clang::CompoundStmt::Create(
5173+
clangCtx, body, clang::FPOptionsOverride(),
5174+
clang::SourceLocation(), clang::SourceLocation())
5175+
: body[0]);
51265176
return newMethod;
51275177
}
51285178

@@ -5163,12 +5213,28 @@ synthesizeBaseClassFieldGetterBody(AbstractFunctionDecl *afd, void *context) {
51635213
RemoveReference,
51645214
/*forceConstQualifier=*/true);
51655215
} else if (auto *fd = dyn_cast_or_null<clang::FieldDecl>(baseClangDecl)) {
5216+
ValueDecl *retainOperationFn = nullptr;
5217+
// Check if this field getter is returning a retainable FRT.
5218+
if (getterDecl->getResultInterfaceType()->isForeignReferenceType()) {
5219+
auto retainOperation = evaluateOrDefault(
5220+
ctx.evaluator,
5221+
CustomRefCountingOperation({getterDecl->getResultInterfaceType()
5222+
->lookThroughAllOptionalTypes()
5223+
->getClassOrBoundGenericClass(),
5224+
CustomRefCountingOperationKind::retain}),
5225+
{});
5226+
if (retainOperation.kind ==
5227+
CustomRefCountingOperationResult::foundOperation) {
5228+
retainOperationFn = retainOperation.operation;
5229+
}
5230+
}
51665231
// Field getter is represented through a generated
51675232
// C++ method call that returns the value of the base field.
51685233
baseGetterCxxMethod = synthesizeCxxBaseGetterAccessorMethod(
51695234
*static_cast<ClangImporter *>(ctx.getClangModuleLoader()),
51705235
cast<clang::CXXRecordDecl>(derivedStruct->getClangDecl()),
5171-
cast<clang::CXXRecordDecl>(baseStruct->getClangDecl()), fd);
5236+
cast<clang::CXXRecordDecl>(baseStruct->getClangDecl()), fd,
5237+
retainOperationFn);
51725238
}
51735239

51745240
if (!baseGetterCxxMethod) {

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3621,6 +3621,10 @@ class CXXMethodConventions : public CFunctionTypeConventions {
36213621
// possible to make it easy for LLVM to optimize away the thunk.
36223622
return ResultConvention::Indirect;
36233623
}
3624+
if (TheDecl->hasAttr<clang::CFReturnsRetainedAttr>() &&
3625+
resultTL.getLoweredType().isForeignReferenceType()) {
3626+
return ResultConvention::Owned;
3627+
}
36243628
return CFunctionTypeConventions::getResult(resultTL);
36253629
}
36263630
static bool classof(const Conventions *C) {
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#pragma once
2+
3+
#define FRT \
4+
__attribute__((swift_attr("import_reference"))) \
5+
__attribute__((swift_attr("retain:immortal"))) \
6+
__attribute__((swift_attr("release:immortal")))
7+
8+
int &getLiveRefCountedCounter() {
9+
static int counter = 0;
10+
return counter;
11+
}
12+
13+
class RefCounted {
14+
public:
15+
RefCounted() { getLiveRefCountedCounter()++; }
16+
~RefCounted() {
17+
getLiveRefCountedCounter()--;
18+
}
19+
20+
void retain() {
21+
++refCount;
22+
}
23+
void release() {
24+
--refCount;
25+
if (refCount == 0)
26+
delete this;
27+
}
28+
29+
int testVal = 1;
30+
private:
31+
int refCount = 1;
32+
} __attribute__((swift_attr("import_reference")))
33+
__attribute__((swift_attr("retain:retainRefCounted")))
34+
__attribute__((swift_attr("release:releaseRefCounted")));
35+
36+
RefCounted * _Nonnull createRefCounted() {
37+
return new RefCounted;
38+
}
39+
40+
void retainRefCounted(RefCounted *r) {
41+
if (r)
42+
r->retain();
43+
}
44+
void releaseRefCounted(RefCounted *r) {
45+
if (r)
46+
r->release();
47+
}
48+
49+
class BaseFieldFRT {
50+
public:
51+
BaseFieldFRT(): value(new RefCounted) {}
52+
BaseFieldFRT(const BaseFieldFRT &other): value(other.value) {
53+
value->retain();
54+
}
55+
~BaseFieldFRT() {
56+
value->release();
57+
}
58+
59+
RefCounted * _Nonnull value;
60+
};
61+
62+
class DerivedFieldFRT : public BaseFieldFRT {
63+
};
64+
65+
class NonEmptyBase {
66+
public:
67+
int getY() const {
68+
return y;
69+
}
70+
private:
71+
int y = 11;
72+
};
73+
74+
class DerivedDerivedFieldFRT : public NonEmptyBase, public DerivedFieldFRT {
75+
};

test/Interop/Cxx/foreign-reference/Inputs/module.modulemap

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,9 @@ module MemberInheritance {
3737
header "member-inheritance.h"
3838
requires cplusplus
3939
}
40+
41+
module DerivedFieldGetterReturnsOwnedFRT {
42+
header "derived-field-getter-returns-owned-frt.h"
43+
requires cplusplus
44+
}
45+
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: %target-swift-emit-irgen -I %S/Inputs -enable-experimental-cxx-interop %s -validate-tbd-against-ir=none -Xcc -fignore-exceptions | %FileCheck %s
2+
3+
import DerivedFieldGetterReturnsOwnedFRT
4+
5+
func testGetX() -> CInt {
6+
let derived = DerivedFieldFRT()
7+
return derived.value.testVal
8+
}
9+
10+
let _ = testGetX()
11+
12+
13+
// CHECK: define {{.*}}linkonce_odr{{.*}} ptr @{{.*}}__synthesizedBaseGetterAccessor_{{.*}}(ptr {{.*}} %[[THIS_PTR:.*]])
14+
// CHECK: %[[VALUE_PTR_PTR:.*]] = getelementptr inbounds %class.BaseFieldFRT, ptr %{{.*}}, i32 0, i32 0
15+
// CHECK: %[[VALUE_PTR:.*]] = load ptr, ptr %[[VALUE_PTR_PTR]], align 8
16+
// CHECK: call void @{{.*}}retainRefCounted{{.*}}(ptr noundef %[[VALUE_PTR]])
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// RUN: %target-swift-emit-sil -I %S/Inputs -enable-experimental-cxx-interop %s -validate-tbd-against-ir=none -Xcc -fignore-exceptions | %FileCheck %s
2+
3+
import DerivedFieldGetterReturnsOwnedFRT
4+
5+
func testGetX() -> CInt {
6+
let derived = DerivedFieldFRT()
7+
return derived.value.testVal
8+
}
9+
10+
let _ = testGetX()
11+
12+
// CHECK: function_ref @{{.*}}__synthesizedBaseGetterAccessor_{{.*}} : $@convention(cxx_method) (@in_guaranteed DerivedFieldFRT) -> @owned RefCounted
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// RUN: %target-run-simple-swift(-I %S/Inputs/ -Xfrontend -enable-experimental-cxx-interop -Xcc -fignore-exceptions -Xfrontend -disable-availability-checking -O)
2+
// RUN: %target-run-simple-swift(-I %S/Inputs/ -Xfrontend -enable-experimental-cxx-interop -Xcc -fignore-exceptions -Xfrontend -disable-availability-checking)
3+
//
4+
// REQUIRES: executable_test
5+
6+
import StdlibUnittest
7+
import DerivedFieldGetterReturnsOwnedFRT
8+
9+
var FunctionsTestSuite = TestSuite("Calling functions in base foreign reference classes")
10+
11+
FunctionsTestSuite.test("base member FRT field accessing shared reference FRT") {
12+
let refScope = {
13+
let frt: RefCounted
14+
do {
15+
let base = BaseFieldFRT()
16+
frt = base.value
17+
}
18+
expectEqual(getLiveRefCountedCounter().pointee, 1)
19+
return frt.testVal
20+
}
21+
let p = refScope()
22+
expectEqual(p, 1)
23+
expectEqual(getLiveRefCountedCounter().pointee, 0)
24+
}
25+
26+
FunctionsTestSuite.test("derived-to-base member FRT field accessing shared reference FRT") {
27+
let refScope = {
28+
let frt: RefCounted
29+
do {
30+
let derived = DerivedFieldFRT()
31+
frt = derived.value
32+
}
33+
expectEqual(getLiveRefCountedCounter().pointee, 1)
34+
return frt.testVal
35+
}
36+
let p = refScope()
37+
expectEqual(p, 1)
38+
expectEqual(getLiveRefCountedCounter().pointee, 0)
39+
}
40+
41+
FunctionsTestSuite.test("derivedDerived-to-derived-to-base member FRT field accessing shared reference FRT") {
42+
let refScope = {
43+
let frt: RefCounted
44+
do {
45+
let derived = DerivedDerivedFieldFRT()
46+
frt = derived.value
47+
}
48+
expectEqual(getLiveRefCountedCounter().pointee, 1)
49+
return frt.testVal
50+
}
51+
let p = refScope()
52+
expectEqual(p, 1)
53+
expectEqual(getLiveRefCountedCounter().pointee, 0)
54+
}
55+
56+
FunctionsTestSuite.test("base member FRT field setting shared reference FRT") {
57+
let frt = createRefCounted()
58+
var base = BaseFieldFRT()
59+
base.value = frt
60+
expectEqual(getLiveRefCountedCounter().pointee, 1)
61+
}
62+
63+
FunctionsTestSuite.test("derived member FRT field setting shared reference FRT") {
64+
let frt = createRefCounted()
65+
var base = DerivedFieldFRT()
66+
base.value = frt
67+
expectEqual(getLiveRefCountedCounter().pointee, 1)
68+
}
69+
70+
runAllTests()

0 commit comments

Comments
 (0)