Skip to content

🍒 [6.2] [cxx-interop] Allow C++ function templates to be instantiated with Swift closures #81214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6576,13 +6576,16 @@ const clang::Type *
ASTContext::getClangFunctionType(ArrayRef<AnyFunctionType::Param> params,
Type resultTy,
FunctionTypeRepresentation trueRep) {
return getClangTypeConverter().getFunctionType(params, resultTy, trueRep);
return getClangTypeConverter().getFunctionType(params, resultTy, trueRep,
/*templateArgument=*/false);
}

const clang::Type *ASTContext::getCanonicalClangFunctionType(
ArrayRef<SILParameterInfo> params, std::optional<SILResultInfo> result,
SILFunctionType::Representation trueRep) {
auto *ty = getClangTypeConverter().getFunctionType(params, result, trueRep);
auto *ty =
getClangTypeConverter().getFunctionType(params, result, trueRep,
/*templateArgument=*/false);
return ty ? ty->getCanonicalTypeInternal().getTypePtr() : nullptr;
}

Expand Down
70 changes: 49 additions & 21 deletions lib/AST/ClangTypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "clang/Basic/TargetInfo.h"
#include "clang/Sema/Sema.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Compiler.h"

Expand Down Expand Up @@ -124,17 +125,18 @@ const clang::ASTContext &clangCtx) {

const clang::Type *ClangTypeConverter::getFunctionType(
ArrayRef<AnyFunctionType::Param> params, Type resultTy,
AnyFunctionType::Representation repr) {

auto resultClangTy = convert(resultTy);
AnyFunctionType::Representation repr, bool templateArgument) {
auto resultClangTy =
templateArgument ? convertTemplateArgument(resultTy) : convert(resultTy);
if (resultClangTy.isNull())
return nullptr;

SmallVector<clang::FunctionProtoType::ExtParameterInfo, 4> extParamInfos;
SmallVector<clang::QualType, 4> paramsClangTy;
bool someParamIsConsumed = false;
for (auto p : params) {
auto pc = convert(p.getPlainType());
auto pc = templateArgument ? convertTemplateArgument(p.getPlainType())
: convert(p.getPlainType());
if (pc.isNull())
return nullptr;
clang::FunctionProtoType::ExtParameterInfo extParamInfo;
Expand Down Expand Up @@ -165,16 +167,19 @@ const clang::Type *ClangTypeConverter::getFunctionType(
llvm_unreachable("invalid representation");
}

const clang::Type *
ClangTypeConverter::getFunctionType(ArrayRef<SILParameterInfo> params,
std::optional<SILResultInfo> result,
SILFunctionType::Representation repr) {

// Using the interface type is sufficient as type parameters get mapped to
// `id`, since ObjC lightweight generics use type erasure. (See also: SE-0057)
auto resultClangTy = result.has_value()
? convert(result.value().getInterfaceType())
: ClangASTContext.VoidTy;
const clang::Type *ClangTypeConverter::getFunctionType(
ArrayRef<SILParameterInfo> params, std::optional<SILResultInfo> result,
SILFunctionType::Representation repr, bool templateArgument) {
clang::QualType resultClangTy = ClangASTContext.VoidTy;
if (result) {
// Using the interface type is sufficient as type parameters get mapped to
// `id`, since ObjC lightweight generics use type erasure.
//
// (See also: SE-0057)
auto interfaceType = result->getInterfaceType();
resultClangTy = templateArgument ? convertTemplateArgument(interfaceType)
: convert(interfaceType);
}

if (resultClangTy.isNull())
return nullptr;
Expand All @@ -183,7 +188,8 @@ ClangTypeConverter::getFunctionType(ArrayRef<SILParameterInfo> params,
SmallVector<clang::QualType, 4> paramsClangTy;
bool someParamIsConsumed = false;
for (auto &p : params) {
auto pc = convert(p.getInterfaceType());
auto pc = templateArgument ? convertTemplateArgument(p.getInterfaceType())
: convert(p.getInterfaceType());
if (pc.isNull())
return nullptr;
clang::FunctionProtoType::ExtParameterInfo extParamInfo;
Expand Down Expand Up @@ -651,7 +657,8 @@ clang::QualType ClangTypeConverter::visitEnumType(EnumType *type) {
return convert(type->getDecl()->getRawType());
}

clang::QualType ClangTypeConverter::visitFunctionType(FunctionType *type) {
clang::QualType ClangTypeConverter::visitFunctionType(FunctionType *type,
bool templateArgument) {
const clang::Type *clangTy = nullptr;
auto repr = type->getRepresentation();
bool useClangTypes = type->getASTContext().LangOpts.UseClangFunctionTypes;
Expand All @@ -665,12 +672,15 @@ clang::QualType ClangTypeConverter::visitFunctionType(FunctionType *type) {
auto newRepr = (repr == FunctionTypeRepresentation::Swift
? FunctionTypeRepresentation::Block
: repr);
clangTy = getFunctionType(type->getParams(), type->getResult(), newRepr);
clangTy = getFunctionType(type->getParams(), type->getResult(), newRepr,
templateArgument);
}
return clang::QualType(clangTy, 0);
}

clang::QualType ClangTypeConverter::visitSILFunctionType(SILFunctionType *type) {
clang::QualType
ClangTypeConverter::visitSILFunctionType(SILFunctionType *type,
bool templateArgument) {
const clang::Type *clangTy = nullptr;
auto repr = type->getRepresentation();
bool useClangTypes = type->getASTContext().LangOpts.UseClangFunctionTypes;
Expand All @@ -688,7 +698,8 @@ clang::QualType ClangTypeConverter::visitSILFunctionType(SILFunctionType *type)
auto optionalResult = results.empty()
? std::nullopt
: std::optional<SILResultInfo>(results[0]);
clangTy = getFunctionType(type->getParameters(), optionalResult, newRepr);
clangTy = getFunctionType(type->getParameters(), optionalResult, newRepr,
templateArgument);
}
return clang::QualType(clangTy, 0);
}
Expand Down Expand Up @@ -933,6 +944,13 @@ clang::QualType ClangTypeConverter::convertTemplateArgument(Type type) {
if (auto floatType = type->getAs<BuiltinFloatType>())
return withCache([&]() { return visitBuiltinFloatType(floatType); });

if (auto tupleType = type->getAs<TupleType>()) {
// We do not call visitTupleType() because we cannot yet handle tuples with
// a non-zero number of elements.
if (tupleType->getNumElements() == 0)
return ClangASTContext.VoidTy;
}

if (auto structType = type->getAs<StructType>()) {
// Swift structs are not supported in general, but some foreign types are
// imported as Swift structs. We reverse that mapping here.
Expand All @@ -953,8 +971,6 @@ clang::QualType ClangTypeConverter::convertTemplateArgument(Type type) {
return withCache([&]() { return reverseBuiltinTypeMapping(structType); });
}

// TODO: function pointers are not yet supported, but they should be.

if (auto boundGenericType = type->getAs<BoundGenericType>()) {
if (boundGenericType->getGenericArgs().size() != 1)
// Must've got something other than a T?, *Pointer<T>, or SIMD*<T>
Expand Down Expand Up @@ -991,6 +1007,18 @@ clang::QualType ClangTypeConverter::convertTemplateArgument(Type type) {
return clang::QualType();
}

if (auto functionType = type->getAs<FunctionType>()) {
return withCache([&]() {
return visitFunctionType(functionType, /*templateArgument=*/true);
});
}

if (auto functionType = type->getAs<SILFunctionType>()) {
return withCache([&]() {
return visitSILFunctionType(functionType, /*templateArgument=*/true);
});
}

// Most types cannot be used to instantiate C++ function templates; give up.
return clang::QualType();
}
Expand Down
16 changes: 10 additions & 6 deletions lib/AST/ClangTypeConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,16 @@ class ClangTypeConverter :
/// \returns The appropriate clang type on success, nullptr on failure.
///
/// Precondition: The representation argument must be C-compatible.
const clang::Type *getFunctionType(
ArrayRef<AnyFunctionType::Param> params, Type resultTy,
AnyFunctionType::Representation repr);
const clang::Type *getFunctionType(ArrayRef<AnyFunctionType::Param> params,
Type resultTy,
AnyFunctionType::Representation repr,
bool templateArgument);

/// Compute the C function type for a SIL function type.
const clang::Type *getFunctionType(ArrayRef<SILParameterInfo> params,
std::optional<SILResultInfo> result,
SILFunctionType::Representation repr);
SILFunctionType::Representation repr,
bool templateArgument);

/// Check whether the given Clang declaration is an export of a Swift
/// declaration introduced by this converter, and if so, return the original
Expand Down Expand Up @@ -148,15 +150,17 @@ class ClangTypeConverter :
clang::QualType visitBoundGenericClassType(BoundGenericClassType *type);
clang::QualType visitBoundGenericType(BoundGenericType *type);
clang::QualType visitEnumType(EnumType *type);
clang::QualType visitFunctionType(FunctionType *type);
clang::QualType visitFunctionType(FunctionType *type,
bool templateArgument = false);
clang::QualType visitProtocolCompositionType(ProtocolCompositionType *type);
clang::QualType visitExistentialType(ExistentialType *type);
clang::QualType visitBuiltinRawPointerType(BuiltinRawPointerType *type);
clang::QualType visitBuiltinIntegerType(BuiltinIntegerType *type);
clang::QualType visitBuiltinFloatType(BuiltinFloatType *type);
clang::QualType visitArchetypeType(ArchetypeType *type);
clang::QualType visitDependentMemberType(DependentMemberType *type);
clang::QualType visitSILFunctionType(SILFunctionType *type);
clang::QualType visitSILFunctionType(SILFunctionType *type,
bool templateArgument = false);
clang::QualType visitGenericTypeParamType(GenericTypeParamType *type);
clang::QualType visitDynamicSelfType(DynamicSelfType *type);
clang::QualType visitSILBlockStorageType(SILBlockStorageType *type);
Expand Down
37 changes: 36 additions & 1 deletion test/Interop/Cxx/templates/Inputs/function-templates.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,27 @@ template <class T> void expectsConstCharPtr(T str) { takesString(str); }
template <long x> void hasNonTypeTemplateParameter() {}
template <long x = 0> void hasDefaultedNonTypeTemplateParameter() {}

// NOTE: these will cause multi-def linker errors if used in more than one compilation unit
int *intPtr;
int (*functionPtr)(void);

int get42(void) { return 42; }
int (*functionPtrGet42)(void) = &get42;
int (*_Nonnull nonNullFunctionPtrGet42)(void) = &get42;

int tripleInt(int x) { return x * 3; }
int (*functionPtrTripleInt)(int) = &tripleInt;
int (*_Nonnull nonNullFunctionPtrTripleInt)(int) = &tripleInt;

int (^blockReturns111)(void) = ^{ return 111; };
int (^_Nonnull nonNullBlockReturns222)(void) = ^{ return 222; };

int (^blockTripleInt)(int) = ^(int x) { return x * 3; };
int (^_Nonnull nonNullBlockTripleInt)(int) = ^(int x) { return x * 3; };

// These functions construct block literals that capture a local variable, and
// then feed those blocks back to Swift via the given Swift closure (cb).
void getConstantIntBlock(int returnValue, void (^_Nonnull cb)(int (^_Nonnull)(void))) { cb(^{ return returnValue; }); }
int getMultiplyIntBlock(int multiplier, int (^_Nonnull cb)(int (^_Nonnull)(int))) { return cb(^(int x) { return x * multiplier; }); }

// We cannot yet use this in Swift but, make sure we don't crash when parsing
// it.
Expand Down Expand Up @@ -59,6 +78,7 @@ struct PlainStruct {
struct CxxClass {
int x;
void method() {}
int getX() const { return x; }
};

struct __attribute__((swift_attr("import_reference")))
Expand Down Expand Up @@ -102,6 +122,21 @@ template <class T> void forwardingReference(T &&) {}

template <class T> void PointerTemplateParameter(T*){}

template <typename F> void callFunction(F f) { f(); }
template <typename F, typename T> void callFunctionWithParam(F f, T t) { f(t); }
template <typename F, typename T> T callFunctionWithReturn(F f) { return f(); }
template <typename F, typename T> T callFunctionWithPassthrough(F f, T t) { return f(t); }

static inline void callBlock(void (^_Nonnull callback)(void)) { callback(); }
template <typename F> void indirectlyCallFunction(F f) { callBlock(f); }
template <typename F> void indirectlyCallFunctionTemplate(F f) { callFunction(f); }

static inline void callBlockWith42(void (^_Nonnull callback)(int)) { callback(42); }
template <typename F> void indirectlyCallFunctionWith42(F f) { callBlockWith42(f); }

static inline void callBlockWithCxxClass24(void (^_Nonnull cb)(CxxClass)) { CxxClass c = {24}; cb(c); }
template <typename F> void indirectlyCallFunctionWithCxxClass24(F f) { callBlockWithCxxClass24(f); }

namespace Orbiters {

template<class T>
Expand Down
Loading