Skip to content

[cxx-interop] Allow C++ function templates to be instantiated with Swift closures #81016

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6576,13 +6576,16 @@ const clang::Type *
ASTContext::getClangFunctionType(ArrayRef<AnyFunctionType::Param> params,
Type resultTy,
FunctionTypeRepresentation trueRep) {
return getClangTypeConverter().getFunctionType(params, resultTy, trueRep);
return getClangTypeConverter().getFunctionType</*templateArgument=*/false>(
params, resultTy, trueRep);
}

const clang::Type *ASTContext::getCanonicalClangFunctionType(
ArrayRef<SILParameterInfo> params, std::optional<SILResultInfo> result,
SILFunctionType::Representation trueRep) {
auto *ty = getClangTypeConverter().getFunctionType(params, result, trueRep);
auto *ty =
getClangTypeConverter().getFunctionType</*templateArgument=*/false>(
params, result, trueRep);
return ty ? ty->getCanonicalTypeInternal().getTypePtr() : nullptr;
}

Expand Down
94 changes: 63 additions & 31 deletions lib/AST/ClangTypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "clang/Basic/TargetInfo.h"
#include "clang/Sema/Sema.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Compiler.h"

Expand Down Expand Up @@ -122,19 +123,22 @@ const clang::ASTContext &clangCtx) {

} // end anonymous namespace

const clang::Type *ClangTypeConverter::getFunctionType(
ArrayRef<AnyFunctionType::Param> params, Type resultTy,
AnyFunctionType::Representation repr) {

auto resultClangTy = convert(resultTy);
template <bool templateArgument>
const clang::Type *
ClangTypeConverter::getFunctionType(ArrayRef<AnyFunctionType::Param> params,
Type resultTy,
AnyFunctionType::Representation repr) {
auto resultClangTy =
templateArgument ? convertTemplateArgument(resultTy) : convert(resultTy);
if (resultClangTy.isNull())
return nullptr;

SmallVector<clang::FunctionProtoType::ExtParameterInfo, 4> extParamInfos;
SmallVector<clang::QualType, 4> paramsClangTy;
bool someParamIsConsumed = false;
for (auto p : params) {
auto pc = convert(p.getPlainType());
auto pc = templateArgument ? convertTemplateArgument(p.getPlainType())
: convert(p.getPlainType());
if (pc.isNull())
return nullptr;
clang::FunctionProtoType::ExtParameterInfo extParamInfo;
Expand Down Expand Up @@ -165,16 +169,21 @@ const clang::Type *ClangTypeConverter::getFunctionType(
llvm_unreachable("invalid representation");
}

template <bool templateArgument>
const clang::Type *
ClangTypeConverter::getFunctionType(ArrayRef<SILParameterInfo> params,
std::optional<SILResultInfo> result,
SILFunctionType::Representation repr) {

// Using the interface type is sufficient as type parameters get mapped to
// `id`, since ObjC lightweight generics use type erasure. (See also: SE-0057)
auto resultClangTy = result.has_value()
? convert(result.value().getInterfaceType())
: ClangASTContext.VoidTy;
clang::QualType resultClangTy = ClangASTContext.VoidTy;
if (result) {
// Using the interface type is sufficient as type parameters get mapped to
// `id`, since ObjC lightweight generics use type erasure.
//
// (See also: SE-0057)
auto interfaceType = result->getInterfaceType();
resultClangTy = templateArgument ? convertTemplateArgument(interfaceType)
: convert(interfaceType);
}

if (resultClangTy.isNull())
return nullptr;
Expand All @@ -183,7 +192,8 @@ ClangTypeConverter::getFunctionType(ArrayRef<SILParameterInfo> params,
SmallVector<clang::QualType, 4> paramsClangTy;
bool someParamIsConsumed = false;
for (auto &p : params) {
auto pc = convert(p.getInterfaceType());
auto pc = templateArgument ? convertTemplateArgument(p.getInterfaceType())
: convert(p.getInterfaceType());
if (pc.isNull())
return nullptr;
clang::FunctionProtoType::ExtParameterInfo extParamInfo;
Expand Down Expand Up @@ -565,18 +575,18 @@ ClangTypeConverter::visitBoundGenericType(BoundGenericType *type) {
}

if (auto kind = classifyPointer(type))
return convertPointerType(argType, kind.value(),
/*templateArgument=*/false);
return convertPointerType</*templateArgument=*/false>(argType,
kind.value());

if (auto width = classifySIMD(type))
return convertSIMDType(argType, width.value(), /*templateArgument=*/false);
return convertSIMDType</*templateArgument=*/false>(argType, width.value());

return clang::QualType();
}

template <bool templateArgument>
clang::QualType ClangTypeConverter::convertSIMDType(CanType scalarType,
unsigned width,
bool templateArgument) {
unsigned width) {
clang::QualType scalarTy = templateArgument
? convertTemplateArgument(scalarType)
: convert(scalarType);
Expand All @@ -588,9 +598,9 @@ clang::QualType ClangTypeConverter::convertSIMDType(CanType scalarType,
return vectorTy;
}

template <bool templateArgument>
clang::QualType ClangTypeConverter::convertPointerType(CanType pointeeType,
PointerKind kind,
bool templateArgument) {
PointerKind kind) {
switch (kind) {
case PointerKind::Unmanaged:
return templateArgument ? clang::QualType() : convert(pointeeType);
Expand Down Expand Up @@ -651,6 +661,7 @@ clang::QualType ClangTypeConverter::visitEnumType(EnumType *type) {
return convert(type->getDecl()->getRawType());
}

template <bool templateArgument>
clang::QualType ClangTypeConverter::visitFunctionType(FunctionType *type) {
const clang::Type *clangTy = nullptr;
auto repr = type->getRepresentation();
Expand All @@ -665,12 +676,15 @@ clang::QualType ClangTypeConverter::visitFunctionType(FunctionType *type) {
auto newRepr = (repr == FunctionTypeRepresentation::Swift
? FunctionTypeRepresentation::Block
: repr);
clangTy = getFunctionType(type->getParams(), type->getResult(), newRepr);
clangTy = getFunctionType<templateArgument>(type->getParams(),
type->getResult(), newRepr);
}
return clang::QualType(clangTy, 0);
}

clang::QualType ClangTypeConverter::visitSILFunctionType(SILFunctionType *type) {
template <bool templateArgument>
clang::QualType
ClangTypeConverter::visitSILFunctionType(SILFunctionType *type) {
const clang::Type *clangTy = nullptr;
auto repr = type->getRepresentation();
bool useClangTypes = type->getASTContext().LangOpts.UseClangFunctionTypes;
Expand All @@ -688,7 +702,8 @@ clang::QualType ClangTypeConverter::visitSILFunctionType(SILFunctionType *type)
auto optionalResult = results.empty()
? std::nullopt
: std::optional<SILResultInfo>(results[0]);
clangTy = getFunctionType(type->getParameters(), optionalResult, newRepr);
clangTy = getFunctionType<templateArgument>(type->getParameters(),
optionalResult, newRepr);
}
return clang::QualType(clangTy, 0);
}
Expand Down Expand Up @@ -933,6 +948,13 @@ clang::QualType ClangTypeConverter::convertTemplateArgument(Type type) {
if (auto floatType = type->getAs<BuiltinFloatType>())
return withCache([&]() { return visitBuiltinFloatType(floatType); });

if (auto tupleType = type->getAs<TupleType>()) {
// We do not call visitTupleType() because we cannot yet handle tuples with
// a non-zero number of elements.
if (tupleType->getNumElements() == 0)
return ClangASTContext.VoidTy;
}

if (auto structType = type->getAs<StructType>()) {
// Swift structs are not supported in general, but some foreign types are
// imported as Swift structs. We reverse that mapping here.
Expand All @@ -953,8 +975,6 @@ clang::QualType ClangTypeConverter::convertTemplateArgument(Type type) {
return withCache([&]() { return reverseBuiltinTypeMapping(structType); });
}

// TODO: function pointers are not yet supported, but they should be.

if (auto boundGenericType = type->getAs<BoundGenericType>()) {
if (boundGenericType->getGenericArgs().size() != 1)
// Must've got something other than a T?, *Pointer<T>, or SIMD*<T>
Expand All @@ -968,8 +988,8 @@ clang::QualType ClangTypeConverter::convertTemplateArgument(Type type) {
auto pointeeType = argType->getAs<BoundGenericType>()
->getGenericArgs()[0]
->getCanonicalType();
return convertPointerType(pointeeType, kind.value(),
/*templateArgument=*/true);
return convertPointerType</*templateArgument=*/true>(pointeeType,
kind.value());
});

// Arbitrary optional types are not (yet) supported
Expand All @@ -978,19 +998,31 @@ clang::QualType ClangTypeConverter::convertTemplateArgument(Type type) {

if (auto kind = classifyPointer(boundGenericType))
return withCache([&]() {
return convertPointerType(argType, kind.value(),
/*templateArgument=*/true);
return convertPointerType</*templateArgument=*/true>(argType,
kind.value());
});

if (auto width = classifySIMD(boundGenericType))
return withCache([&]() {
return convertSIMDType(argType, width.value(),
/*templateArgument=*/true);
return convertSIMDType</*templateArgument=*/true>(argType,
width.value());
});

return clang::QualType();
}

if (auto functionType = type->getAs<FunctionType>()) {
return withCache([&]() {
return visitFunctionType</*templateArgument=*/true>(functionType);
});
}

if (auto functionType = type->getAs<SILFunctionType>()) {
return withCache([&]() {
return visitSILFunctionType</*templateArgument=*/true>(functionType);
});
}

// Most types cannot be used to instantiate C++ function templates; give up.
return clang::QualType();
}
Expand Down
18 changes: 11 additions & 7 deletions lib/AST/ClangTypeConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,13 @@ class ClangTypeConverter :
/// \returns The appropriate clang type on success, nullptr on failure.
///
/// Precondition: The representation argument must be C-compatible.
const clang::Type *getFunctionType(
ArrayRef<AnyFunctionType::Param> params, Type resultTy,
AnyFunctionType::Representation repr);
template <bool templateArgument>
const clang::Type *getFunctionType(ArrayRef<AnyFunctionType::Param> params,
Type resultTy,
AnyFunctionType::Representation repr);

/// Compute the C function type for a SIL function type.
template <bool templateArgument>
const clang::Type *getFunctionType(ArrayRef<SILParameterInfo> params,
std::optional<SILResultInfo> result,
SILFunctionType::Representation repr);
Expand Down Expand Up @@ -125,11 +127,11 @@ class ClangTypeConverter :

clang::QualType convertClangDecl(Type type, const clang::Decl *decl);

clang::QualType convertSIMDType(CanType scalarType, unsigned width,
bool templateArgument);
template <bool templateArgument>
clang::QualType convertSIMDType(CanType scalarType, unsigned width);

clang::QualType convertPointerType(CanType pointeeType, PointerKind kind,
bool templateArgument);
template <bool templateArgument>
clang::QualType convertPointerType(CanType pointeeType, PointerKind kind);

void registerExportedClangDecl(Decl *swiftDecl,
const clang::Decl *clangDecl);
Expand All @@ -148,6 +150,7 @@ class ClangTypeConverter :
clang::QualType visitBoundGenericClassType(BoundGenericClassType *type);
clang::QualType visitBoundGenericType(BoundGenericType *type);
clang::QualType visitEnumType(EnumType *type);
template <bool templateArgument = false>
clang::QualType visitFunctionType(FunctionType *type);
clang::QualType visitProtocolCompositionType(ProtocolCompositionType *type);
clang::QualType visitExistentialType(ExistentialType *type);
Expand All @@ -156,6 +159,7 @@ class ClangTypeConverter :
clang::QualType visitBuiltinFloatType(BuiltinFloatType *type);
clang::QualType visitArchetypeType(ArchetypeType *type);
clang::QualType visitDependentMemberType(DependentMemberType *type);
template <bool templateArgument = false>
clang::QualType visitSILFunctionType(SILFunctionType *type);
clang::QualType visitGenericTypeParamType(GenericTypeParamType *type);
clang::QualType visitDynamicSelfType(DynamicSelfType *type);
Expand Down
37 changes: 36 additions & 1 deletion test/Interop/Cxx/templates/Inputs/function-templates.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,27 @@ template <class T> void expectsConstCharPtr(T str) { takesString(str); }
template <long x> void hasNonTypeTemplateParameter() {}
template <long x = 0> void hasDefaultedNonTypeTemplateParameter() {}

// NOTE: these will cause multi-def linker errors if used in more than one compilation unit
int *intPtr;
int (*functionPtr)(void);

int get42(void) { return 42; }
int (*functionPtrGet42)(void) = &get42;
int (*_Nonnull nonNullFunctionPtrGet42)(void) = &get42;

int tripleInt(int x) { return x * 3; }
int (*functionPtrTripleInt)(int) = &tripleInt;
int (*_Nonnull nonNullFunctionPtrTripleInt)(int) = &tripleInt;

int (^blockReturns111)(void) = ^{ return 111; };
int (^_Nonnull nonNullBlockReturns222)(void) = ^{ return 222; };

int (^blockTripleInt)(int) = ^(int x) { return x * 3; };
int (^_Nonnull nonNullBlockTripleInt)(int) = ^(int x) { return x * 3; };

// These functions construct block literals that capture a local variable, and
// then feed those blocks back to Swift via the given Swift closure (cb).
void getConstantIntBlock(int returnValue, void (^_Nonnull cb)(int (^_Nonnull)(void))) { cb(^{ return returnValue; }); }
int getMultiplyIntBlock(int multiplier, int (^_Nonnull cb)(int (^_Nonnull)(int))) { return cb(^(int x) { return x * multiplier; }); }

// We cannot yet use this in Swift but, make sure we don't crash when parsing
// it.
Expand Down Expand Up @@ -59,6 +78,7 @@ struct PlainStruct {
struct CxxClass {
int x;
void method() {}
int getX() const { return x; }
};

struct __attribute__((swift_attr("import_reference")))
Expand Down Expand Up @@ -102,6 +122,21 @@ template <class T> void forwardingReference(T &&) {}

template <class T> void PointerTemplateParameter(T*){}

template <typename F> void callFunction(F f) { f(); }
template <typename F, typename T> void callFunctionWithParam(F f, T t) { f(t); }
template <typename F, typename T> T callFunctionWithReturn(F f) { return f(); }
template <typename F, typename T> T callFunctionWithPassthrough(F f, T t) { return f(t); }

static inline void callBlock(void (^_Nonnull callback)(void)) { callback(); }
template <typename F> void indirectlyCallFunction(F f) { callBlock(f); }
template <typename F> void indirectlyCallFunctionTemplate(F f) { callFunction(f); }

static inline void callBlockWith42(void (^_Nonnull callback)(int)) { callback(42); }
template <typename F> void indirectlyCallFunctionWith42(F f) { callBlockWith42(f); }

static inline void callBlockWithCxxClass24(void (^_Nonnull cb)(CxxClass)) { CxxClass c = {24}; cb(c); }
template <typename F> void indirectlyCallFunctionWithCxxClass24(F f) { callBlockWithCxxClass24(f); }

namespace Orbiters {

template<class T>
Expand Down
Loading