Skip to content

Commit ba5b1ba

Browse files
committed
[cxx-interop] Use a synthesized C++ method when invoking a base method from a derived class synthesized method
The use of a synthesized C++ method allows us to avoid making a copy of self when invoking the base method from Swift
1 parent c92f6af commit ba5b1ba

File tree

9 files changed

+279
-34
lines changed

9 files changed

+279
-34
lines changed

include/swift/AST/DiagnosticsClangImporter.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,10 @@ ERROR(move_only_requires_move_only,none,
247247
"use of noncopyable C++ type '%0' requires -enable-experimental-move-only",
248248
(StringRef))
249249

250+
ERROR(failed_base_method_call_synthesis,none,
251+
"failed to synthesize call to the base method %0 of type %0",
252+
(ValueDecl *, ValueDecl *))
253+
250254
NOTE(unsupported_builtin_type, none, "built-in type '%0' not supported", (StringRef))
251255
NOTE(record_field_not_imported, none, "field %0 unavailable (cannot import)", (const clang::NamedDecl*))
252256
NOTE(invoked_func_not_imported, none, "function %0 unavailable (cannot import)", (const clang::NamedDecl*))

lib/ClangImporter/ClangImporter.cpp

Lines changed: 159 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
#include "clang/Parse/Parser.h"
6565
#include "clang/Rewrite/Frontend/FrontendActions.h"
6666
#include "clang/Rewrite/Frontend/Rewriters.h"
67+
#include "clang/Sema/DelayedDiagnostic.h"
6768
#include "clang/Sema/Lookup.h"
6869
#include "clang/Sema/Sema.h"
6970
#include "clang/Serialization/ASTReader.h"
@@ -4762,27 +4763,165 @@ MemberRefExpr *getSelfInteropStaticCast(FuncDecl *funcDecl,
47624763
return pointeePropertyRefExpr;
47634764
}
47644765

4765-
// For const methods generate the following:
4766-
// %0 = __swift_interopStaticCast<Base>(self)
4767-
// return %0.fn(args...)
4768-
// For mutating methods we have to pass self as a pointer:
4769-
// %0 = Builtin.addressof(&self)
4770-
// %1 = Builtin.reinterpretCast<UnsafeMutablePointer<Derived>>(%0)
4771-
// %2 = __swift_interopStaticCast<UnsafeMutablePointer<Base>?>(%1)
4772-
// %3 = %2!
4773-
// %4 = %3.pointee
4774-
// return %4.fn(args...)
4766+
// Synthesize a C++ method that invokes the method from the base
4767+
// class. This lets Clang take care of the cast from the derived class
4768+
// to the base class during the invocation of the method.
4769+
static clang::CXXMethodDecl *synthesizeCxxBaseMethod(
4770+
ClangImporter &impl, const clang::CXXRecordDecl *derivedClass,
4771+
const clang::CXXRecordDecl *baseClass, const clang::CXXMethodDecl *method) {
4772+
auto &clangCtx = impl.getClangASTContext();
4773+
auto &clangSema = impl.getClangSema();
4774+
4775+
// Create a new method in the derived class that calls the base method.
4776+
auto name = method->getNameInfo().getName();
4777+
if (name.isIdentifier()) {
4778+
std::string newName;
4779+
llvm::raw_string_ostream os(newName);
4780+
os << "__synthesizedBaseCall_" << name.getAsIdentifierInfo()->getName();
4781+
name = clang::DeclarationName(
4782+
&impl.getClangPreprocessor().getIdentifierTable().get(os.str()));
4783+
}
4784+
auto newMethod = clang::CXXMethodDecl::Create(
4785+
clangCtx, const_cast<clang::CXXRecordDecl *>(derivedClass),
4786+
method->getSourceRange().getBegin(),
4787+
clang::DeclarationNameInfo(name, clang::SourceLocation()),
4788+
method->getType(), method->getTypeSourceInfo(), method->getStorageClass(),
4789+
method->UsesFPIntrin(), /*isInline=*/true, method->getConstexprKind(),
4790+
method->getSourceRange().getEnd());
4791+
newMethod->setImplicit();
4792+
newMethod->setImplicitlyInline();
4793+
newMethod->setAccess(clang::AccessSpecifier::AS_public);
4794+
4795+
llvm::SmallVector<clang::ParmVarDecl *, 4> params;
4796+
for (size_t i = 0; i < method->getNumParams(); ++i) {
4797+
const auto &param = *method->getParamDecl(i);
4798+
params.push_back(clang::ParmVarDecl::Create(
4799+
clangCtx, newMethod, param.getSourceRange().getBegin(),
4800+
param.getLocation(), param.getIdentifier(), param.getType(),
4801+
param.getTypeSourceInfo(), param.getStorageClass(),
4802+
/*DefExpr=*/nullptr));
4803+
}
4804+
newMethod->setParams(params);
4805+
4806+
// Create a new Clang diagnostic pool to capture any diagnostics
4807+
// emitted during the construction of the method.
4808+
clang::sema::DelayedDiagnosticPool diagPool{
4809+
clangSema.DelayedDiagnostics.getCurrentPool()};
4810+
auto diagState = clangSema.DelayedDiagnostics.push(diagPool);
4811+
4812+
// Construct the method's body.
4813+
auto *thisExpr = new (clangCtx) clang::CXXThisExpr(
4814+
clang::SourceLocation(), newMethod->getThisType(), /*IsImplicit=*/false);
4815+
auto memberExpr = clangSema.BuildMemberExpr(
4816+
thisExpr, /*isArrow=*/true, clang::SourceLocation(),
4817+
clang::NestedNameSpecifierLoc(), clang::SourceLocation(),
4818+
const_cast<clang::CXXMethodDecl *>(method),
4819+
clang::DeclAccessPair::make(const_cast<clang::CXXMethodDecl *>(method),
4820+
clang::AS_public),
4821+
/*HadMultipleCandidates=*/false, method->getNameInfo(),
4822+
clangCtx.BoundMemberTy, clang::VK_PRValue, clang::OK_Ordinary);
4823+
llvm::SmallVector<clang::Expr *, 4> args;
4824+
for (size_t i = 0; i < newMethod->getNumParams(); ++i) {
4825+
auto *param = newMethod->getParamDecl(i);
4826+
auto type = param->getType();
4827+
if (type->isReferenceType())
4828+
type = type->getPointeeType();
4829+
args.push_back(new (clangCtx) clang::DeclRefExpr(
4830+
clangCtx, param, false, type, clang::ExprValueKind::VK_LValue,
4831+
clang::SourceLocation()));
4832+
}
4833+
auto memberCall = clangSema.BuildCallToMemberFunction(
4834+
nullptr, memberExpr, clang::SourceLocation(), args,
4835+
clang::SourceLocation());
4836+
if (!memberCall.isUsable())
4837+
return nullptr;
4838+
auto returnStmt = clang::ReturnStmt::Create(clangCtx, clang::SourceLocation(),
4839+
memberCall.get(), nullptr);
4840+
4841+
// Check if there were any Clang errors during the construction
4842+
// of the method body.
4843+
clangSema.DelayedDiagnostics.popWithoutEmitting(diagState);
4844+
if (!diagPool.empty())
4845+
return nullptr;
4846+
4847+
newMethod->setBody(returnStmt);
4848+
return newMethod;
4849+
}
4850+
4851+
// Find the base C++ method called by the base function we want to synthesize
4852+
// the derived thunk for.
4853+
// The base C++ method is either the original C++ method that corresponds
4854+
// to the imported base member, or it's the synthesized C++ method thunk
4855+
// used in another synthesized derived thunk that acts as a base member here.
4856+
const clang::CXXMethodDecl *getCalledBaseCxxMethod(FuncDecl *baseMember) {
4857+
if (baseMember->getClangDecl())
4858+
return dyn_cast<clang::CXXMethodDecl>(baseMember->getClangDecl());
4859+
// Another synthesized derived thunk is used as a base member here,
4860+
// so extract its synthesized C++ method.
4861+
auto body = baseMember->getBody();
4862+
if (body->getElements().empty())
4863+
return nullptr;
4864+
ReturnStmt *returnStmt = dyn_cast_or_null<ReturnStmt>(
4865+
body->getElements().front().dyn_cast<Stmt *>());
4866+
if (!returnStmt)
4867+
return nullptr;
4868+
auto *callExpr = dyn_cast<CallExpr>(returnStmt->getResult());
4869+
if (!callExpr)
4870+
return nullptr;
4871+
auto *cv = callExpr->getCalledValue();
4872+
if (!cv)
4873+
return nullptr;
4874+
if (!cv->getClangDecl())
4875+
return nullptr;
4876+
return dyn_cast<clang::CXXMethodDecl>(cv->getClangDecl());
4877+
}
4878+
4879+
// Construct a Swift method that represents the synthesized C++ method
4880+
// that invokes the base C++ method.
4881+
FuncDecl *synthesizeBaseFunctionDeclCall(ClangImporter &impl, ASTContext &ctx,
4882+
NominalTypeDecl *derivedStruct,
4883+
NominalTypeDecl *baseStruct,
4884+
FuncDecl *baseMember) {
4885+
auto *cxxMethod = getCalledBaseCxxMethod(baseMember);
4886+
if (!cxxMethod)
4887+
return nullptr;
4888+
auto *newClangMethod = synthesizeCxxBaseMethod(
4889+
impl, cast<clang::CXXRecordDecl>(derivedStruct->getClangDecl()),
4890+
cast<clang::CXXRecordDecl>(baseStruct->getClangDecl()), cxxMethod);
4891+
if (!newClangMethod)
4892+
return nullptr;
4893+
return cast<FuncDecl>(
4894+
ctx.getClangModuleLoader()->importDeclDirectly(newClangMethod));
4895+
}
4896+
4897+
// Generates the body of a derived method, that invokes the base
4898+
// method.
4899+
// The method's body takes the following form:
4900+
// return self.__synthesizedBaseCall_fn(args...)
47754901
static std::pair<BraceStmt *, bool>
47764902
synthesizeBaseClassMethodBody(AbstractFunctionDecl *afd, void *context) {
4903+
47774904
ASTContext &ctx = afd->getASTContext();
47784905

47794906
auto funcDecl = cast<FuncDecl>(afd);
47804907
auto derivedStruct =
47814908
cast<NominalTypeDecl>(funcDecl->getDeclContext()->getAsDecl());
47824909
auto baseMember = static_cast<FuncDecl *>(context);
4783-
auto baseStruct = cast<NominalTypeDecl>(baseMember->getDeclContext()->getAsDecl());
4910+
auto baseStruct =
4911+
cast<NominalTypeDecl>(baseMember->getDeclContext()->getAsDecl());
47844912
auto baseType = baseStruct->getDeclaredType();
47854913

4914+
auto forwardedFunc = synthesizeBaseFunctionDeclCall(
4915+
*static_cast<ClangImporter *>(ctx.getClangModuleLoader()), ctx,
4916+
derivedStruct, baseStruct, baseMember);
4917+
if (!forwardedFunc) {
4918+
ctx.Diags.diagnose(SourceLoc(), diag::failed_base_method_call_synthesis,
4919+
funcDecl, baseStruct);
4920+
auto body = BraceStmt::create(ctx, SourceLoc(), {}, SourceLoc(),
4921+
/*implicit=*/true);
4922+
return {body, /*isTypeChecked=*/true};
4923+
}
4924+
47864925
SmallVector<Expr *, 8> forwardingParams;
47874926
for (auto param : *funcDecl->getParameters()) {
47884927
auto paramRefExpr = new (ctx) DeclRefExpr(param, DeclNameLoc(),
@@ -4791,34 +4930,25 @@ synthesizeBaseClassMethodBody(AbstractFunctionDecl *afd, void *context) {
47914930
forwardingParams.push_back(paramRefExpr);
47924931
}
47934932

4794-
Argument casted = [&]() {
4795-
if (funcDecl->isMutating()) {
4796-
return Argument::implicitInOut(
4797-
ctx, getSelfInteropStaticCast(funcDecl, baseStruct, derivedStruct));
4798-
}
4933+
Argument selfArg = [&]() {
47994934
auto *selfDecl = funcDecl->getImplicitSelfDecl();
48004935
auto selfExpr = new (ctx) DeclRefExpr(selfDecl, DeclNameLoc(),
48014936
/*implicit*/ true);
4937+
if (funcDecl->isMutating()) {
4938+
selfExpr->setType(LValueType::get(selfDecl->getInterfaceType()));
4939+
return Argument::implicitInOut(ctx, selfExpr);
4940+
}
48024941
selfExpr->setType(selfDecl->getTypeInContext());
4803-
4804-
auto staticCastRefExpr = getInteropStaticCastDeclRefExpr(
4805-
ctx, baseStruct->getClangDecl()->getOwningModule(), baseType,
4806-
derivedStruct->getDeclaredType());
4807-
4808-
auto *argList = ArgumentList::forImplicitUnlabeled(ctx, {selfExpr});
4809-
auto castedCall = CallExpr::createImplicit(ctx, staticCastRefExpr, argList);
4810-
castedCall->setType(baseType);
4811-
castedCall->setThrows(false);
4812-
return Argument::unlabeled(castedCall);
4942+
return Argument::unlabeled(selfExpr);
48134943
}();
48144944

48154945
auto *baseMemberExpr =
4816-
new (ctx) DeclRefExpr(ConcreteDeclRef(baseMember), DeclNameLoc(),
4946+
new (ctx) DeclRefExpr(ConcreteDeclRef(forwardedFunc), DeclNameLoc(),
48174947
/*Implicit=*/true);
4818-
baseMemberExpr->setType(baseMember->getInterfaceType());
4948+
baseMemberExpr->setType(forwardedFunc->getInterfaceType());
48194949

48204950
auto baseMemberDotCallExpr =
4821-
DotSyntaxCallExpr::create(ctx, baseMemberExpr, SourceLoc(), casted);
4951+
DotSyntaxCallExpr::create(ctx, baseMemberExpr, SourceLoc(), selfArg);
48224952
baseMemberDotCallExpr->setType(baseMember->getMethodInterfaceType());
48234953
baseMemberDotCallExpr->setThrows(false);
48244954

test/Interop/Cxx/class/inheritance/Inputs/functions.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,48 @@ struct DerivedFromEmptyBaseClass : EmptyBaseClass {
110110
int a = 42;
111111
int b = 42;
112112
};
113+
114+
int &getCopyCounter() {
115+
static int copyCounter = 0;
116+
return copyCounter;
117+
}
118+
119+
class CopyTrackedBaseClass {
120+
public:
121+
CopyTrackedBaseClass(int x) : x(x) {}
122+
CopyTrackedBaseClass(const CopyTrackedBaseClass &other) : x(other.x) {
123+
++getCopyCounter();
124+
}
125+
126+
int getX() const {
127+
return x;
128+
}
129+
int getXMut() {
130+
return x;
131+
}
132+
private:
133+
int x;
134+
};
135+
136+
class CopyTrackedDerivedClass: public CopyTrackedBaseClass {
137+
public:
138+
CopyTrackedDerivedClass(int x) : CopyTrackedBaseClass(x) {}
139+
140+
int getDerivedX() const {
141+
return getX();
142+
}
143+
};
144+
145+
class NonEmptyBase {
146+
public:
147+
int getY() const {
148+
return y;
149+
}
150+
private:
151+
int y = 11;
152+
};
153+
154+
class CopyTrackedDerivedDerivedClass: public NonEmptyBase, public CopyTrackedDerivedClass {
155+
public:
156+
CopyTrackedDerivedDerivedClass(int x) : CopyTrackedDerivedClass(x) {}
157+
};
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: %target-swift-emit-irgen -I %S/Inputs -enable-experimental-cxx-interop %s -validate-tbd-against-ir=none -Xcc -fignore-exceptions | %FileCheck %s
2+
3+
import Functions
4+
5+
func testGetX() -> CInt {
6+
let derivedDerived = CopyTrackedDerivedDerivedClass(42)
7+
return derivedDerived.getX()
8+
}
9+
10+
let _ = testGetX()
11+
12+
// CHECK: define {{.*}} swiftcc i32 @"$sSo018CopyTrackedDerivedC5ClassV4getXs5Int32VyF"(ptr noalias swiftself dereferenceable(8) %[[SELF_PTR:.*]])
13+
// CHECK: = call i32 @[[SYNTH_METHOD:.*]](ptr %[[SELF_PTR]])
14+
15+
// CHECK: define {{.*}}linkonce_odr{{.*}} i32 @[[SYNTH_METHOD]](ptr {{.*}} %[[THIS_PTR:.*]])
16+
// CHECK: %[[ADD_PTR:.*]] = getelementptr inbounds i8, ptr %{{.*}}, i64 4
17+
// CHECK: call noundef i32 @{{.*}}(ptr {{.*}} %[[ADD_PTR]])
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: %target-swift-emit-sil -I %S/Inputs -enable-experimental-cxx-interop %s -validate-tbd-against-ir=none | %FileCheck %s
2+
3+
import Functions
4+
5+
func testGetX() -> CInt {
6+
let derived = CopyTrackedDerivedClass(42)
7+
return derived.getX()
8+
}
9+
10+
let _ = testGetX()
11+
12+
// CHECK: sil shared @$sSo23CopyTrackedDerivedClassV4getXs5Int32VyF : $@convention(method) (@in_guaranteed CopyTrackedDerivedClass) -> Int32
13+
// CHECK: {{.*}}(%[[SELF_VAL:.*]] : $*CopyTrackedDerivedClass):
14+
// CHECK: function_ref @{{.*}}__synthesizedBaseCall_{{.*}} : $@convention(cxx_method) (@in_guaranteed CopyTrackedDerivedClass) -> Int32
15+
// CHECK-NEXT: apply %{{.*}}(%[[SELF_VAL]])

test/Interop/Cxx/class/inheritance/functions.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,26 @@ FunctionsTestSuite.test("non-empty derived from empty class") {
8282
expectEqual(derived.b, 42)
8383
}
8484

85+
FunctionsTestSuite.test("base member calls do not require copying") {
86+
let derived = CopyTrackedDerivedClass(42)
87+
var copyCounter = getCopyCounter().pointee
88+
expectEqual(derived.getX(), 42)
89+
expectEqual(copyCounter, getCopyCounter().pointee)
90+
expectEqual(derived.getDerivedX(), 42)
91+
expectEqual(copyCounter, getCopyCounter().pointee)
92+
93+
let derivedDerived = CopyTrackedDerivedDerivedClass(-5)
94+
copyCounter = getCopyCounter().pointee
95+
expectEqual(derivedDerived.getX(), -5)
96+
expectEqual(derivedDerived.getY(), 11)
97+
expectEqual(copyCounter, getCopyCounter().pointee)
98+
}
99+
100+
FunctionsTestSuite.test("mutating base member calls do not require copying") {
101+
var derived = CopyTrackedDerivedClass(42)
102+
var copyCounter = getCopyCounter().pointee
103+
expectEqual(derived.getXMut(), 42)
104+
expectEqual(copyCounter, getCopyCounter().pointee)
105+
}
106+
85107
runAllTests()

test/Interop/Cxx/implementation-only-imports/import-implementation-only-cxx-interop-module-without-interop.swift

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44

55
// RUN: %target-swift-frontend %s -typecheck -module-name TestMod -I %t -I %S/Inputs
66

7-
// Check that we have used something from CxxShim in 'UseModuleAImplOnly'
8-
// RUN: %target-swift-frontend %S/Inputs/use-module-a-impl-only.swift -I %S/Inputs/ -module-name UseModuleAImplOnly -emit-module -emit-module-path %t/UseModuleAImplOnly.swiftmodule -cxx-interoperability-mode=default -emit-sil -o - -enable-library-evolution | %FileCheck %s
9-
// CHECK: __swift_interopStaticCast
10-
117
import UseModuleAImplOnly
128

139
public func testCallsAPI() {

test/Interop/Cxx/stdlib/Inputs/std-vector.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,8 @@ inline std::string takesVectorOfString(const VectorOfString &v) {
1313
return v.front();
1414
}
1515

16-
#endif // TEST_INTEROP_CXX_STDLIB_INPUTS_STD_VECTOR_H
16+
class VectorSubclass: public Vector {
17+
public:
18+
};
19+
20+
#endif // TEST_INTEROP_CXX_STDLIB_INPUTS_STD_VECTOR_H

test/Interop/Cxx/stdlib/use-std-vector.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,4 +119,16 @@ StdVectorTestSuite.test("VectorOfString.map") {
119119
expectEqual(a, [3, 1, 2])
120120
}
121121

122+
StdVectorTestSuite.test("VectorOfInt subclass for loop") {
123+
var v = VectorSubclass()
124+
v.push_back(1)
125+
126+
var count: CInt = 1
127+
for e in v {
128+
expectEqual(e, count)
129+
count += 1
130+
}
131+
expectEqual(count, 2)
132+
}
133+
122134
runAllTests()

0 commit comments

Comments
 (0)