Skip to content

Commit b2fe628

Browse files
authored
Merge pull request #81016 from j-hui/swift-function-as-template-arg
[cxx-interop] Allow C++ function templates to be instantiated with Swift closures
2 parents 7768aa3 + fee1dd3 commit b2fe628

File tree

6 files changed

+316
-48
lines changed

6 files changed

+316
-48
lines changed

lib/AST/ASTContext.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6585,13 +6585,16 @@ const clang::Type *
65856585
ASTContext::getClangFunctionType(ArrayRef<AnyFunctionType::Param> params,
65866586
Type resultTy,
65876587
FunctionTypeRepresentation trueRep) {
6588-
return getClangTypeConverter().getFunctionType(params, resultTy, trueRep);
6588+
return getClangTypeConverter().getFunctionType</*templateArgument=*/false>(
6589+
params, resultTy, trueRep);
65896590
}
65906591

65916592
const clang::Type *ASTContext::getCanonicalClangFunctionType(
65926593
ArrayRef<SILParameterInfo> params, std::optional<SILResultInfo> result,
65936594
SILFunctionType::Representation trueRep) {
6594-
auto *ty = getClangTypeConverter().getFunctionType(params, result, trueRep);
6595+
auto *ty =
6596+
getClangTypeConverter().getFunctionType</*templateArgument=*/false>(
6597+
params, result, trueRep);
65956598
return ty ? ty->getCanonicalTypeInternal().getTypePtr() : nullptr;
65966599
}
65976600

lib/AST/ClangTypeConverter.cpp

Lines changed: 63 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "clang/Basic/TargetInfo.h"
4141
#include "clang/Sema/Sema.h"
4242

43+
#include "llvm/ADT/STLExtras.h"
4344
#include "llvm/ADT/StringSwitch.h"
4445
#include "llvm/Support/Compiler.h"
4546

@@ -122,19 +123,22 @@ const clang::ASTContext &clangCtx) {
122123

123124
} // end anonymous namespace
124125

125-
const clang::Type *ClangTypeConverter::getFunctionType(
126-
ArrayRef<AnyFunctionType::Param> params, Type resultTy,
127-
AnyFunctionType::Representation repr) {
128-
129-
auto resultClangTy = convert(resultTy);
126+
template <bool templateArgument>
127+
const clang::Type *
128+
ClangTypeConverter::getFunctionType(ArrayRef<AnyFunctionType::Param> params,
129+
Type resultTy,
130+
AnyFunctionType::Representation repr) {
131+
auto resultClangTy =
132+
templateArgument ? convertTemplateArgument(resultTy) : convert(resultTy);
130133
if (resultClangTy.isNull())
131134
return nullptr;
132135

133136
SmallVector<clang::FunctionProtoType::ExtParameterInfo, 4> extParamInfos;
134137
SmallVector<clang::QualType, 4> paramsClangTy;
135138
bool someParamIsConsumed = false;
136139
for (auto p : params) {
137-
auto pc = convert(p.getPlainType());
140+
auto pc = templateArgument ? convertTemplateArgument(p.getPlainType())
141+
: convert(p.getPlainType());
138142
if (pc.isNull())
139143
return nullptr;
140144
clang::FunctionProtoType::ExtParameterInfo extParamInfo;
@@ -165,16 +169,21 @@ const clang::Type *ClangTypeConverter::getFunctionType(
165169
llvm_unreachable("invalid representation");
166170
}
167171

172+
template <bool templateArgument>
168173
const clang::Type *
169174
ClangTypeConverter::getFunctionType(ArrayRef<SILParameterInfo> params,
170175
std::optional<SILResultInfo> result,
171176
SILFunctionType::Representation repr) {
172-
173-
// Using the interface type is sufficient as type parameters get mapped to
174-
// `id`, since ObjC lightweight generics use type erasure. (See also: SE-0057)
175-
auto resultClangTy = result.has_value()
176-
? convert(result.value().getInterfaceType())
177-
: ClangASTContext.VoidTy;
177+
clang::QualType resultClangTy = ClangASTContext.VoidTy;
178+
if (result) {
179+
// Using the interface type is sufficient as type parameters get mapped to
180+
// `id`, since ObjC lightweight generics use type erasure.
181+
//
182+
// (See also: SE-0057)
183+
auto interfaceType = result->getInterfaceType();
184+
resultClangTy = templateArgument ? convertTemplateArgument(interfaceType)
185+
: convert(interfaceType);
186+
}
178187

179188
if (resultClangTy.isNull())
180189
return nullptr;
@@ -183,7 +192,8 @@ ClangTypeConverter::getFunctionType(ArrayRef<SILParameterInfo> params,
183192
SmallVector<clang::QualType, 4> paramsClangTy;
184193
bool someParamIsConsumed = false;
185194
for (auto &p : params) {
186-
auto pc = convert(p.getInterfaceType());
195+
auto pc = templateArgument ? convertTemplateArgument(p.getInterfaceType())
196+
: convert(p.getInterfaceType());
187197
if (pc.isNull())
188198
return nullptr;
189199
clang::FunctionProtoType::ExtParameterInfo extParamInfo;
@@ -565,18 +575,18 @@ ClangTypeConverter::visitBoundGenericType(BoundGenericType *type) {
565575
}
566576

567577
if (auto kind = classifyPointer(type))
568-
return convertPointerType(argType, kind.value(),
569-
/*templateArgument=*/false);
578+
return convertPointerType</*templateArgument=*/false>(argType,
579+
kind.value());
570580

571581
if (auto width = classifySIMD(type))
572-
return convertSIMDType(argType, width.value(), /*templateArgument=*/false);
582+
return convertSIMDType</*templateArgument=*/false>(argType, width.value());
573583

574584
return clang::QualType();
575585
}
576586

587+
template <bool templateArgument>
577588
clang::QualType ClangTypeConverter::convertSIMDType(CanType scalarType,
578-
unsigned width,
579-
bool templateArgument) {
589+
unsigned width) {
580590
clang::QualType scalarTy = templateArgument
581591
? convertTemplateArgument(scalarType)
582592
: convert(scalarType);
@@ -588,9 +598,9 @@ clang::QualType ClangTypeConverter::convertSIMDType(CanType scalarType,
588598
return vectorTy;
589599
}
590600

601+
template <bool templateArgument>
591602
clang::QualType ClangTypeConverter::convertPointerType(CanType pointeeType,
592-
PointerKind kind,
593-
bool templateArgument) {
603+
PointerKind kind) {
594604
switch (kind) {
595605
case PointerKind::Unmanaged:
596606
return templateArgument ? clang::QualType() : convert(pointeeType);
@@ -651,6 +661,7 @@ clang::QualType ClangTypeConverter::visitEnumType(EnumType *type) {
651661
return convert(type->getDecl()->getRawType());
652662
}
653663

664+
template <bool templateArgument>
654665
clang::QualType ClangTypeConverter::visitFunctionType(FunctionType *type) {
655666
const clang::Type *clangTy = nullptr;
656667
auto repr = type->getRepresentation();
@@ -665,12 +676,15 @@ clang::QualType ClangTypeConverter::visitFunctionType(FunctionType *type) {
665676
auto newRepr = (repr == FunctionTypeRepresentation::Swift
666677
? FunctionTypeRepresentation::Block
667678
: repr);
668-
clangTy = getFunctionType(type->getParams(), type->getResult(), newRepr);
679+
clangTy = getFunctionType<templateArgument>(type->getParams(),
680+
type->getResult(), newRepr);
669681
}
670682
return clang::QualType(clangTy, 0);
671683
}
672684

673-
clang::QualType ClangTypeConverter::visitSILFunctionType(SILFunctionType *type) {
685+
template <bool templateArgument>
686+
clang::QualType
687+
ClangTypeConverter::visitSILFunctionType(SILFunctionType *type) {
674688
const clang::Type *clangTy = nullptr;
675689
auto repr = type->getRepresentation();
676690
bool useClangTypes = type->getASTContext().LangOpts.UseClangFunctionTypes;
@@ -688,7 +702,8 @@ clang::QualType ClangTypeConverter::visitSILFunctionType(SILFunctionType *type)
688702
auto optionalResult = results.empty()
689703
? std::nullopt
690704
: std::optional<SILResultInfo>(results[0]);
691-
clangTy = getFunctionType(type->getParameters(), optionalResult, newRepr);
705+
clangTy = getFunctionType<templateArgument>(type->getParameters(),
706+
optionalResult, newRepr);
692707
}
693708
return clang::QualType(clangTy, 0);
694709
}
@@ -933,6 +948,13 @@ clang::QualType ClangTypeConverter::convertTemplateArgument(Type type) {
933948
if (auto floatType = type->getAs<BuiltinFloatType>())
934949
return withCache([&]() { return visitBuiltinFloatType(floatType); });
935950

951+
if (auto tupleType = type->getAs<TupleType>()) {
952+
// We do not call visitTupleType() because we cannot yet handle tuples with
953+
// a non-zero number of elements.
954+
if (tupleType->getNumElements() == 0)
955+
return ClangASTContext.VoidTy;
956+
}
957+
936958
if (auto structType = type->getAs<StructType>()) {
937959
// Swift structs are not supported in general, but some foreign types are
938960
// imported as Swift structs. We reverse that mapping here.
@@ -953,8 +975,6 @@ clang::QualType ClangTypeConverter::convertTemplateArgument(Type type) {
953975
return withCache([&]() { return reverseBuiltinTypeMapping(structType); });
954976
}
955977

956-
// TODO: function pointers are not yet supported, but they should be.
957-
958978
if (auto boundGenericType = type->getAs<BoundGenericType>()) {
959979
if (boundGenericType->getGenericArgs().size() != 1)
960980
// Must've got something other than a T?, *Pointer<T>, or SIMD*<T>
@@ -968,8 +988,8 @@ clang::QualType ClangTypeConverter::convertTemplateArgument(Type type) {
968988
auto pointeeType = argType->getAs<BoundGenericType>()
969989
->getGenericArgs()[0]
970990
->getCanonicalType();
971-
return convertPointerType(pointeeType, kind.value(),
972-
/*templateArgument=*/true);
991+
return convertPointerType</*templateArgument=*/true>(pointeeType,
992+
kind.value());
973993
});
974994

975995
// Arbitrary optional types are not (yet) supported
@@ -978,19 +998,31 @@ clang::QualType ClangTypeConverter::convertTemplateArgument(Type type) {
978998

979999
if (auto kind = classifyPointer(boundGenericType))
9801000
return withCache([&]() {
981-
return convertPointerType(argType, kind.value(),
982-
/*templateArgument=*/true);
1001+
return convertPointerType</*templateArgument=*/true>(argType,
1002+
kind.value());
9831003
});
9841004

9851005
if (auto width = classifySIMD(boundGenericType))
9861006
return withCache([&]() {
987-
return convertSIMDType(argType, width.value(),
988-
/*templateArgument=*/true);
1007+
return convertSIMDType</*templateArgument=*/true>(argType,
1008+
width.value());
9891009
});
9901010

9911011
return clang::QualType();
9921012
}
9931013

1014+
if (auto functionType = type->getAs<FunctionType>()) {
1015+
return withCache([&]() {
1016+
return visitFunctionType</*templateArgument=*/true>(functionType);
1017+
});
1018+
}
1019+
1020+
if (auto functionType = type->getAs<SILFunctionType>()) {
1021+
return withCache([&]() {
1022+
return visitSILFunctionType</*templateArgument=*/true>(functionType);
1023+
});
1024+
}
1025+
9941026
// Most types cannot be used to instantiate C++ function templates; give up.
9951027
return clang::QualType();
9961028
}

lib/AST/ClangTypeConverter.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,13 @@ class ClangTypeConverter :
7070
/// \returns The appropriate clang type on success, nullptr on failure.
7171
///
7272
/// Precondition: The representation argument must be C-compatible.
73-
const clang::Type *getFunctionType(
74-
ArrayRef<AnyFunctionType::Param> params, Type resultTy,
75-
AnyFunctionType::Representation repr);
73+
template <bool templateArgument>
74+
const clang::Type *getFunctionType(ArrayRef<AnyFunctionType::Param> params,
75+
Type resultTy,
76+
AnyFunctionType::Representation repr);
7677

7778
/// Compute the C function type for a SIL function type.
79+
template <bool templateArgument>
7880
const clang::Type *getFunctionType(ArrayRef<SILParameterInfo> params,
7981
std::optional<SILResultInfo> result,
8082
SILFunctionType::Representation repr);
@@ -125,11 +127,11 @@ class ClangTypeConverter :
125127

126128
clang::QualType convertClangDecl(Type type, const clang::Decl *decl);
127129

128-
clang::QualType convertSIMDType(CanType scalarType, unsigned width,
129-
bool templateArgument);
130+
template <bool templateArgument>
131+
clang::QualType convertSIMDType(CanType scalarType, unsigned width);
130132

131-
clang::QualType convertPointerType(CanType pointeeType, PointerKind kind,
132-
bool templateArgument);
133+
template <bool templateArgument>
134+
clang::QualType convertPointerType(CanType pointeeType, PointerKind kind);
133135

134136
void registerExportedClangDecl(Decl *swiftDecl,
135137
const clang::Decl *clangDecl);
@@ -148,6 +150,7 @@ class ClangTypeConverter :
148150
clang::QualType visitBoundGenericClassType(BoundGenericClassType *type);
149151
clang::QualType visitBoundGenericType(BoundGenericType *type);
150152
clang::QualType visitEnumType(EnumType *type);
153+
template <bool templateArgument = false>
151154
clang::QualType visitFunctionType(FunctionType *type);
152155
clang::QualType visitProtocolCompositionType(ProtocolCompositionType *type);
153156
clang::QualType visitExistentialType(ExistentialType *type);
@@ -156,6 +159,7 @@ class ClangTypeConverter :
156159
clang::QualType visitBuiltinFloatType(BuiltinFloatType *type);
157160
clang::QualType visitArchetypeType(ArchetypeType *type);
158161
clang::QualType visitDependentMemberType(DependentMemberType *type);
162+
template <bool templateArgument = false>
159163
clang::QualType visitSILFunctionType(SILFunctionType *type);
160164
clang::QualType visitGenericTypeParamType(GenericTypeParamType *type);
161165
clang::QualType visitDynamicSelfType(DynamicSelfType *type);

test/Interop/Cxx/templates/Inputs/function-templates.h

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,27 @@ template <class T> void expectsConstCharPtr(T str) { takesString(str); }
1515
template <long x> void hasNonTypeTemplateParameter() {}
1616
template <long x = 0> void hasDefaultedNonTypeTemplateParameter() {}
1717

18+
// NOTE: these will cause multi-def linker errors if used in more than one compilation unit
1819
int *intPtr;
19-
int (*functionPtr)(void);
20+
21+
int get42(void) { return 42; }
22+
int (*functionPtrGet42)(void) = &get42;
23+
int (*_Nonnull nonNullFunctionPtrGet42)(void) = &get42;
24+
25+
int tripleInt(int x) { return x * 3; }
26+
int (*functionPtrTripleInt)(int) = &tripleInt;
27+
int (*_Nonnull nonNullFunctionPtrTripleInt)(int) = &tripleInt;
28+
29+
int (^blockReturns111)(void) = ^{ return 111; };
30+
int (^_Nonnull nonNullBlockReturns222)(void) = ^{ return 222; };
31+
32+
int (^blockTripleInt)(int) = ^(int x) { return x * 3; };
33+
int (^_Nonnull nonNullBlockTripleInt)(int) = ^(int x) { return x * 3; };
34+
35+
// These functions construct block literals that capture a local variable, and
36+
// then feed those blocks back to Swift via the given Swift closure (cb).
37+
void getConstantIntBlock(int returnValue, void (^_Nonnull cb)(int (^_Nonnull)(void))) { cb(^{ return returnValue; }); }
38+
int getMultiplyIntBlock(int multiplier, int (^_Nonnull cb)(int (^_Nonnull)(int))) { return cb(^(int x) { return x * multiplier; }); }
2039

2140
// We cannot yet use this in Swift but, make sure we don't crash when parsing
2241
// it.
@@ -59,6 +78,7 @@ struct PlainStruct {
5978
struct CxxClass {
6079
int x;
6180
void method() {}
81+
int getX() const { return x; }
6282
};
6383

6484
struct __attribute__((swift_attr("import_reference")))
@@ -102,6 +122,21 @@ template <class T> void forwardingReference(T &&) {}
102122

103123
template <class T> void PointerTemplateParameter(T*){}
104124

125+
template <typename F> void callFunction(F f) { f(); }
126+
template <typename F, typename T> void callFunctionWithParam(F f, T t) { f(t); }
127+
template <typename F, typename T> T callFunctionWithReturn(F f) { return f(); }
128+
template <typename F, typename T> T callFunctionWithPassthrough(F f, T t) { return f(t); }
129+
130+
static inline void callBlock(void (^_Nonnull callback)(void)) { callback(); }
131+
template <typename F> void indirectlyCallFunction(F f) { callBlock(f); }
132+
template <typename F> void indirectlyCallFunctionTemplate(F f) { callFunction(f); }
133+
134+
static inline void callBlockWith42(void (^_Nonnull callback)(int)) { callback(42); }
135+
template <typename F> void indirectlyCallFunctionWith42(F f) { callBlockWith42(f); }
136+
137+
static inline void callBlockWithCxxClass24(void (^_Nonnull cb)(CxxClass)) { CxxClass c = {24}; cb(c); }
138+
template <typename F> void indirectlyCallFunctionWithCxxClass24(F f) { callBlockWithCxxClass24(f); }
139+
105140
namespace Orbiters {
106141

107142
template<class T>

0 commit comments

Comments
 (0)