Skip to content

Commit 8ebee73

Browse files
committed
[cxx-interop] Allow initializing std::function from Swift closures
This adds a Swift initializer to instantiations of `std::function` that accepts a Swift closure with `@convention(c)`. rdar://103979602
1 parent f56fa41 commit 8ebee73

File tree

5 files changed

+128
-3
lines changed

5 files changed

+128
-3
lines changed

lib/ClangImporter/ClangDerivedConformances.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,3 +1019,90 @@ void swift::conformToCxxVectorIfNeeded(ClangImporter::Implementation &impl,
10191019
rawIteratorTy);
10201020
impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxVector});
10211021
}
1022+
1023+
void swift::conformToCxxFunctionIfNeeded(
1024+
ClangImporter::Implementation &impl, NominalTypeDecl *decl,
1025+
const clang::CXXRecordDecl *clangDecl) {
1026+
PrettyStackTraceDecl trace("conforming to CxxFunction", decl);
1027+
1028+
assert(decl);
1029+
assert(clangDecl);
1030+
ASTContext &ctx = decl->getASTContext();
1031+
clang::ASTContext &clangCtx = impl.getClangASTContext();
1032+
clang::Sema &clangSema = impl.getClangSema();
1033+
1034+
// Only auto-conform types from the C++ standard library. Custom user types
1035+
// might have a similar interface but different semantics.
1036+
if (!isStdDecl(clangDecl, {"function"}))
1037+
return;
1038+
1039+
// There is no typealias for the argument types on the C++ side, so to
1040+
// retrieve the argument types we look at the overload of `operator()` that
1041+
// got imported into Swift.
1042+
1043+
auto callAsFunctionDecl = lookupDirectSingleWithoutExtensions<FuncDecl>(
1044+
decl, ctx.getIdentifier("callAsFunction"));
1045+
if (!callAsFunctionDecl)
1046+
return;
1047+
1048+
auto operatorCallDecl = dyn_cast_or_null<clang::CXXMethodDecl>(
1049+
callAsFunctionDecl->getClangDecl());
1050+
if (!operatorCallDecl)
1051+
return;
1052+
1053+
std::vector<clang::QualType> operatorCallParamTypes;
1054+
llvm::transform(
1055+
operatorCallDecl->parameters(),
1056+
std::back_inserter(operatorCallParamTypes),
1057+
[](const clang::ParmVarDecl *paramDecl) { return paramDecl->getType(); });
1058+
1059+
auto funcPointerType = clangCtx.getPointerType(clangCtx.getFunctionType(
1060+
operatorCallDecl->getReturnType(), operatorCallParamTypes,
1061+
clang::FunctionProtoType::ExtProtoInfo()));
1062+
1063+
// Create a fake variable with a function type that matches the type of
1064+
// `operator()`.
1065+
auto fakeFuncPointerVarDecl = clang::VarDecl::Create(
1066+
clangCtx, /*DC*/ clangCtx.getTranslationUnitDecl(),
1067+
clang::SourceLocation(), clang::SourceLocation(), /*Id*/ nullptr,
1068+
funcPointerType, clangCtx.getTrivialTypeSourceInfo(funcPointerType),
1069+
clang::StorageClass::SC_None);
1070+
auto fakeFuncPointerRefExpr = new (clangCtx) clang::DeclRefExpr(
1071+
clangCtx, fakeFuncPointerVarDecl, false, funcPointerType,
1072+
clang::ExprValueKind::VK_LValue, clang::SourceLocation());
1073+
1074+
auto clangDeclTyInfo = clangCtx.getTrivialTypeSourceInfo(
1075+
clang::QualType(clangDecl->getTypeForDecl(), 0));
1076+
SmallVector<clang::Expr *, 1> constructExprArgs = {fakeFuncPointerRefExpr};
1077+
1078+
// Instantiate the templated constructor that would accept this fake variable.
1079+
auto constructExprResult = clangSema.BuildCXXTypeConstructExpr(
1080+
clangDeclTyInfo, clangDecl->getLocation(), constructExprArgs,
1081+
clangDecl->getLocation(), /*ListInitialization*/ false);
1082+
if (!constructExprResult.isUsable())
1083+
return;
1084+
1085+
auto castExpr = dyn_cast_or_null<clang::CastExpr>(constructExprResult.get());
1086+
if (!castExpr)
1087+
return;
1088+
1089+
auto bindTempExpr =
1090+
dyn_cast_or_null<clang::CXXBindTemporaryExpr>(castExpr->getSubExpr());
1091+
if (!bindTempExpr)
1092+
return;
1093+
1094+
auto constructExpr =
1095+
dyn_cast_or_null<clang::CXXConstructExpr>(bindTempExpr->getSubExpr());
1096+
if (!constructExpr)
1097+
return;
1098+
1099+
auto constructorDecl = constructExpr->getConstructor();
1100+
1101+
auto importedConstructor =
1102+
impl.importDecl(constructorDecl, impl.CurrentVersion);
1103+
if (!importedConstructor)
1104+
return;
1105+
decl->addMember(importedConstructor);
1106+
1107+
// TODO: actually conform to some form of CxxFunction protocol
1108+
}

lib/ClangImporter/ClangDerivedConformances.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ void conformToCxxVectorIfNeeded(ClangImporter::Implementation &impl,
7171
NominalTypeDecl *decl,
7272
const clang::CXXRecordDecl *clangDecl);
7373

74+
/// If the decl is an instantiation of C++ `std::function`, synthesize a
75+
/// conformance to CxxFunction, which is defined in the Cxx module.
76+
void conformToCxxFunctionIfNeeded(ClangImporter::Implementation &impl,
77+
NominalTypeDecl *decl,
78+
const clang::CXXRecordDecl *clangDecl);
79+
7480
} // namespace swift
7581

7682
#endif // SWIFT_CLANG_DERIVED_CONFORMANCES_H

lib/ClangImporter/ImportDecl.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2815,6 +2815,7 @@ namespace {
28152815
conformToCxxPairIfNeeded(Impl, nominalDecl, decl);
28162816
conformToCxxOptionalIfNeeded(Impl, nominalDecl, decl);
28172817
conformToCxxVectorIfNeeded(Impl, nominalDecl, decl);
2818+
conformToCxxFunctionIfNeeded(Impl, nominalDecl, decl);
28182819
}
28192820
}
28202821

test/Interop/Cxx/stdlib/Inputs/std-function.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
#define TEST_INTEROP_CXX_STDLIB_INPUTS_STD_FUNCTION_H
33

44
#include <functional>
5+
#include <string>
56

67
using FunctionIntToInt = std::function<int(int)>;
8+
using FunctionStringToString = std::function<std::string(const std::string&)>;
79

810
inline FunctionIntToInt getIdentityFunction() {
911
return [](int x) { return x; };
@@ -13,4 +15,8 @@ inline bool isEmptyFunction(FunctionIntToInt f) { return !(bool)f; }
1315

1416
inline int invokeFunction(FunctionIntToInt f, int x) { return f(x); }
1517

18+
std::string invokeFunctionTwice(FunctionStringToString f, std::string s) {
19+
return f(f(s));
20+
}
21+
1622
#endif // TEST_INTEROP_CXX_STDLIB_INPUTS_STD_FUNCTION_H

test/Interop/Cxx/stdlib/use-std-function.swift

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,50 @@
99

1010
import StdlibUnittest
1111
import StdFunction
12+
import CxxStdlib
1213

1314
var StdFunctionTestSuite = TestSuite("StdFunction")
1415

15-
StdFunctionTestSuite.test("init empty") {
16+
StdFunctionTestSuite.test("FunctionIntToInt init empty") {
1617
let f = FunctionIntToInt()
1718
expectTrue(isEmptyFunction(f))
1819

1920
let copied = f
2021
expectTrue(isEmptyFunction(copied))
2122
}
2223

23-
StdFunctionTestSuite.test("call") {
24+
StdFunctionTestSuite.test("FunctionIntToInt call") {
2425
let f = getIdentityFunction()
2526
expectEqual(123, f(123))
2627
}
2728

28-
StdFunctionTestSuite.test("retrieve and pass back as parameter") {
29+
StdFunctionTestSuite.test("FunctionIntToInt retrieve and pass back as parameter") {
2930
let res = invokeFunction(getIdentityFunction(), 456)
3031
expectEqual(456, res)
3132
}
3233

34+
StdFunctionTestSuite.test("FunctionIntToInt init from closure and call") {
35+
let cClosure: @convention(c) (Int32) -> Int32 = { $0 + 1 }
36+
let f = FunctionIntToInt(cClosure)
37+
expectEqual(1, f(0))
38+
expectEqual(124, f(123))
39+
expectEqual(0, f(-1))
40+
41+
let f2 = FunctionIntToInt({ $0 * 2 })
42+
expectEqual(0, f2(0))
43+
expectEqual(246, f2(123))
44+
}
45+
46+
StdFunctionTestSuite.test("FunctionIntToInt init from closure and pass as parameter") {
47+
let res = invokeFunction(.init({ $0 * 2 }), 111)
48+
expectEqual(222, res)
49+
}
50+
51+
// FIXME: assertion for address-only closure params
52+
//StdFunctionTestSuite.test("FunctionStringToString init from closure and pass as parameter") {
53+
// let res = invokeFunctionTwice(.init({ $0 + std.string("abc") }),
54+
// std.string("prefix"))
55+
// expectEqual(std.string("prefixabcabc"), res)
56+
//}
57+
3358
runAllTests()

0 commit comments

Comments
 (0)