Skip to content

Commit b4d7a0c

Browse files
committed
[interop][SwiftToCxx] bridge returned C++ record types back to C++ from Swift
1 parent 1d16403 commit b4d7a0c

File tree

10 files changed

+339
-32
lines changed

10 files changed

+339
-32
lines changed

lib/PrintAsClang/ModuleContentsWriter.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "swift/AST/SwiftNameTranslation.h"
2727
#include "swift/AST/TypeDeclFinder.h"
2828
#include "swift/ClangImporter/ClangImporter.h"
29+
#include "swift/Strings.h"
2930

3031
#include "clang/AST/Decl.h"
3132
#include "clang/Basic/Module.h"
@@ -178,6 +179,14 @@ class ModuleWriter {
178179
}
179180
}
180181

182+
if (outputLangMode == OutputLanguageMode::Cxx) {
183+
// Only add C++ imports in C++ mode for now.
184+
if (!D->hasClangNode())
185+
return true;
186+
if (otherModule->getName().str() == CLANG_HEADER_MODULE_NAME)
187+
return true;
188+
}
189+
181190
imports.insert(otherModule);
182191
return true;
183192
}
@@ -258,8 +267,10 @@ class ModuleWriter {
258267
if (outputLangMode == OutputLanguageMode::Cxx) {
259268
if (isa<StructDecl>(TD) || isa<EnumDecl>(TD)) {
260269
auto *NTD = cast<NominalTypeDecl>(TD);
261-
forwardDeclare(
262-
NTD, [&]() { ClangValueTypePrinter::forwardDeclType(os, NTD); });
270+
if (!addImport(NTD)) {
271+
forwardDeclare(
272+
NTD, [&]() { ClangValueTypePrinter::forwardDeclType(os, NTD); });
273+
}
263274
}
264275
return;
265276
}

lib/PrintAsClang/PrintAsClang.cpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,11 @@ static int compareImportModulesByName(const ImportModuleTy *left,
395395

396396
static void writeImports(raw_ostream &out,
397397
llvm::SmallPtrSetImpl<ImportModuleTy> &imports,
398-
ModuleDecl &M, StringRef bridgingHeader) {
399-
out << "#if __has_feature(modules)\n";
398+
ModuleDecl &M, StringRef bridgingHeader,
399+
bool useCxxImport = false) {
400+
// Note: we can't use has_feature(modules) as it's always enabled in C++20
401+
// mode.
402+
out << "#if __has_feature(objc_modules)\n";
400403

401404
out << "#if __has_warning(\"-Watimport-in-framework-header\")\n"
402405
<< "#pragma clang diagnostic ignored \"-Watimport-in-framework-header\"\n"
@@ -420,6 +423,8 @@ static void writeImports(raw_ostream &out,
420423
// Track printed names to handle overlay modules.
421424
llvm::SmallPtrSet<Identifier, 8> seenImports;
422425
bool includeUnderlying = false;
426+
StringRef importDirective =
427+
useCxxImport ? "#pragma clang module import" : "@import";
423428
for (auto import : sortedImports) {
424429
if (auto *swiftModule = import.dyn_cast<ModuleDecl *>()) {
425430
auto Name = swiftModule->getName();
@@ -428,12 +433,12 @@ static void writeImports(raw_ostream &out,
428433
continue;
429434
}
430435
if (seenImports.insert(Name).second)
431-
out << "@import " << Name.str() << ";\n";
436+
out << importDirective << ' ' << Name.str() << ";\n";
432437
} else {
433438
const auto *clangModule = import.get<const clang::Module *>();
434439
assert(clangModule->isSubModule() &&
435440
"top-level modules should use a normal swift::ModuleDecl");
436-
out << "@import ";
441+
out << importDirective << ' ';
437442
ModuleDecl::ReverseFullNameIterator(clangModule).printForward(out);
438443
out << ";\n";
439444
}
@@ -489,16 +494,14 @@ static std::string computeMacroGuard(const ModuleDecl *M) {
489494
return (llvm::Twine(M->getNameStr().upper()) + "_SWIFT_H").str();
490495
}
491496

492-
static std::string
493-
getModuleContentsCxxString(ModuleDecl &M,
494-
SwiftToClangInteropContext &interopContext,
495-
bool requiresExposedAttribute) {
496-
SmallPtrSet<ImportModuleTy, 8> imports;
497+
static std::string getModuleContentsCxxString(
498+
ModuleDecl &M, SmallPtrSet<ImportModuleTy, 8> &imports,
499+
SwiftToClangInteropContext &interopContext, bool requiresExposedAttribute) {
497500
std::string moduleContentsBuf;
498501
llvm::raw_string_ostream moduleContents{moduleContentsBuf};
499502
printModuleContentsAsCxx(moduleContents, imports, M, interopContext,
500503
requiresExposedAttribute);
501-
return moduleContents.str();
504+
return std::move(moduleContents.str());
502505
}
503506

504507
bool swift::printAsClangHeader(raw_ostream &os, ModuleDecl *M,
@@ -522,9 +525,13 @@ bool swift::printAsClangHeader(raw_ostream &os, ModuleDecl *M,
522525
// FIXME: Expose Swift with @expose by default.
523526
if (ExposePublicDeclsInClangHeader ||
524527
M->DeclContext::getASTContext().LangOpts.EnableCXXInterop) {
525-
os << getModuleContentsCxxString(
526-
*M, interopContext,
528+
SmallPtrSet<ImportModuleTy, 8> imports;
529+
auto contents = getModuleContentsCxxString(
530+
*M, imports, interopContext,
527531
/*requiresExposedAttribute=*/!ExposePublicDeclsInClangHeader);
532+
// FIXME: In ObjC++ mode, we do not need to reimport duplicate modules.
533+
writeImports(os, imports, *M, bridgingHeader, /*useCxxImport=*/true);
534+
os << contents;
528535
}
529536
});
530537
writeEpilogue(os);

lib/PrintAsClang/PrintClangFunction.cpp

Lines changed: 99 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
#include "swift/AST/TypeVisitor.h"
2828
#include "swift/ClangImporter/ClangImporter.h"
2929
#include "swift/IRGen/IRABIDetailsProvider.h"
30+
#include "clang/AST/ASTContext.h"
31+
#include "clang/AST/DeclTemplate.h"
32+
#include "clang/AST/NestedNameSpecifier.h"
3033
#include "llvm/ADT/STLExtras.h"
3134

3235
using namespace swift;
@@ -90,6 +93,74 @@ struct CFunctionSignatureTypePrinterModifierDelegate {
9093
mapValueTypeUseKind = None;
9194
};
9295

96+
class ClangTypeHandler {
97+
public:
98+
ClangTypeHandler(const clang::Decl *typeDecl) : typeDecl(typeDecl) {}
99+
100+
bool isRepresentable() const {
101+
// We can only return trivial types, or
102+
// types that can be moved or copied.
103+
if (auto *record = dyn_cast<clang::CXXRecordDecl>(typeDecl)) {
104+
return record->isTrivial() || record->hasMoveConstructor() ||
105+
record->hasCopyConstructorWithConstParam();
106+
}
107+
return false;
108+
}
109+
110+
void printTypeName(raw_ostream &os) const {
111+
auto &clangCtx = typeDecl->getASTContext();
112+
clang::PrintingPolicy pp(clangCtx.getLangOpts());
113+
const auto *NS = clang::NestedNameSpecifier::getRequiredQualification(
114+
clangCtx, clangCtx.getTranslationUnitDecl(),
115+
typeDecl->getLexicalDeclContext());
116+
if (NS)
117+
NS->print(os, pp);
118+
assert(cast<clang::NamedDecl>(typeDecl)->getDeclName().isIdentifier());
119+
os << cast<clang::NamedDecl>(typeDecl)->getName();
120+
if (auto *ctd =
121+
dyn_cast<clang::ClassTemplateSpecializationDecl>(typeDecl)) {
122+
if (ctd->getTemplateArgs().size()) {
123+
os << '<';
124+
llvm::interleaveComma(ctd->getTemplateArgs().asArray(), os,
125+
[&](const clang::TemplateArgument &arg) {
126+
arg.print(pp, os, /*IncludeType=*/true);
127+
});
128+
os << '>';
129+
}
130+
}
131+
}
132+
133+
void printReturnScaffold(raw_ostream &os,
134+
llvm::function_ref<void(StringRef)> bodyOfReturn) {
135+
std::string fullQualifiedType;
136+
std::string typeName;
137+
{
138+
llvm::raw_string_ostream typeNameOS(fullQualifiedType);
139+
printTypeName(typeNameOS);
140+
llvm::raw_string_ostream unqualTypeNameOS(typeName);
141+
unqualTypeNameOS << cast<clang::NamedDecl>(typeDecl)->getName();
142+
}
143+
os << "alignas(alignof(" << fullQualifiedType << ")) char storage[sizeof("
144+
<< fullQualifiedType << ")];\n";
145+
os << "auto * _Nonnull storageObjectPtr = reinterpret_cast<"
146+
<< fullQualifiedType << " *>(storage);\n";
147+
bodyOfReturn("storage");
148+
os << ";\n";
149+
auto *cxxRecord = cast<clang::CXXRecordDecl>(typeDecl);
150+
if (cxxRecord->isTrivial()) {
151+
// Trivial object can be just copied and not destroyed.
152+
os << "return *storageObjectPtr;\n";
153+
return;
154+
}
155+
os << fullQualifiedType << " result(std::move(*storageObjectPtr));\n";
156+
os << "storageObjectPtr->~" << typeName << "();\n";
157+
os << "return result;\n";
158+
}
159+
160+
private:
161+
const clang::Decl *typeDecl;
162+
};
163+
93164
// Prints types in the C function signature that corresponds to the
94165
// native Swift function/method.
95166
class CFunctionSignatureTypePrinter
@@ -235,6 +306,14 @@ class CFunctionSignatureTypePrinter
235306
if (languageMode != OutputLanguageMode::Cxx)
236307
return ClangRepresentation::unsupported;
237308

309+
if (decl->hasClangNode()) {
310+
ClangTypeHandler handler(decl->getClangDecl());
311+
if (!handler.isRepresentable())
312+
return ClangRepresentation::unsupported;
313+
handler.printTypeName(os);
314+
return ClangRepresentation::representable;
315+
}
316+
238317
// FIXME: Handle optional structures.
239318
if (typeUseKind == FunctionSignatureTypeUse::ParamType) {
240319
if (!isInOutParam) {
@@ -938,24 +1017,31 @@ void DeclAndTypeClangFunctionPrinter::printCxxThunkBody(
9381017
return;
9391018
}
9401019
if (auto *decl = resultTy->getNominalOrBoundGenericNominal()) {
1020+
auto valueTypeReturnThunker = [&](StringRef resultPointerName) {
1021+
if (auto directResultType = signature.getDirectResultType()) {
1022+
std::string typeEncoding =
1023+
encodeTypeInfo(*directResultType, moduleContext, typeMapping);
1024+
os << cxx_synthesis::getCxxImplNamespaceName()
1025+
<< "::swift_interop_returnDirect_" << typeEncoding << '('
1026+
<< resultPointerName << ", ";
1027+
printCallToCFunc(None);
1028+
os << ')';
1029+
} else {
1030+
printCallToCFunc(/*firstParam=*/resultPointerName);
1031+
}
1032+
};
1033+
if (decl->hasClangNode()) {
1034+
ClangTypeHandler handler(decl->getClangDecl());
1035+
assert(handler.isRepresentable());
1036+
handler.printReturnScaffold(os, valueTypeReturnThunker);
1037+
return;
1038+
}
9411039
ClangValueTypePrinter valueTypePrinter(os, cPrologueOS, interopContext);
9421040

9431041
valueTypePrinter.printValueTypeReturnScaffold(
9441042
decl, moduleContext,
9451043
[&]() { printTypeImplTypeSpecifier(resultTy, moduleContext); },
946-
[&](StringRef resultPointerName) {
947-
if (auto directResultType = signature.getDirectResultType()) {
948-
std::string typeEncoding =
949-
encodeTypeInfo(*directResultType, moduleContext, typeMapping);
950-
os << cxx_synthesis::getCxxImplNamespaceName()
951-
<< "::swift_interop_returnDirect_" << typeEncoding << '('
952-
<< resultPointerName << ", ";
953-
printCallToCFunc(None);
954-
os << ')';
955-
} else {
956-
printCallToCFunc(/*firstParam=*/resultPointerName);
957-
}
958-
});
1044+
valueTypeReturnThunker);
9591045
return;
9601046
}
9611047
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: split-file %s %t
3+
4+
// RUN: %target-swift-frontend -typecheck %t/use-cxx-types.swift -typecheck -module-name UseCxx -emit-clang-header-path %t/UseCxx.h -I %t -enable-experimental-cxx-interop -clang-header-expose-public-decls
5+
6+
// RUN: %target-interop-build-clangxx -c %t/use-swift-cxx-types.cpp -I %t -o %t/swift-cxx-execution.o -g
7+
// RUN: %target-interop-build-swift %t/use-cxx-types.swift -o %t/swift-cxx-execution -Xlinker %t/swift-cxx-execution.o -module-name UseCxx -Xfrontend -entry-point-function-name -Xfrontend swiftMain -I %t -g
8+
9+
// RUN: %target-codesign %t/swift-cxx-execution
10+
// RUN: %target-run %t/swift-cxx-execution | %FileCheck %s
11+
12+
// REQUIRES: executable_test
13+
14+
//--- header.h
15+
16+
extern "C" void puts(const char *);
17+
18+
struct Trivial {
19+
int x, y;
20+
21+
inline Trivial(int x, int y) : x(x), y(y) {}
22+
};
23+
24+
template<class T>
25+
struct NonTrivialTemplate {
26+
T x;
27+
28+
inline NonTrivialTemplate(T x) : x(x) {
29+
puts("create NonTrivialTemplate");
30+
}
31+
inline NonTrivialTemplate(const NonTrivialTemplate<T> &) = default;
32+
inline NonTrivialTemplate(NonTrivialTemplate<T> &&other) : x(static_cast<T &&>(other.x)) {
33+
puts("move NonTrivialTemplate");
34+
}
35+
inline ~NonTrivialTemplate() {
36+
puts("~NonTrivialTemplate");
37+
}
38+
};
39+
40+
//--- module.modulemap
41+
module CxxTest {
42+
header "header.h"
43+
requires cplusplus
44+
}
45+
46+
//--- use-cxx-types.swift
47+
import CxxTest
48+
49+
public func retNonTrivial(y: CInt) -> NonTrivialTemplate<Trivial> {
50+
return NonTrivialTemplate<Trivial>(Trivial(42, y))
51+
}
52+
53+
public func retTrivial(_ x: CInt) -> Trivial {
54+
return Trivial(x, -x)
55+
}
56+
57+
//--- use-swift-cxx-types.cpp
58+
59+
#include "header.h"
60+
#include "UseCxx.h"
61+
#include <assert.h>
62+
63+
int main() {
64+
{
65+
auto x = UseCxx::retTrivial(423421);
66+
assert(x.x == 423421);
67+
assert(x.y == -423421);
68+
}
69+
{
70+
auto x = UseCxx::retNonTrivial(-942);
71+
assert(x.x.y == -942);
72+
assert(x.x.x == 42);
73+
}
74+
// CHECK: create NonTrivialTemplate
75+
// CHECK-NEXT: move NonTrivialTemplate
76+
// CHECK-NEXT: ~NonTrivialTemplate
77+
// CHECK-NEXT: ~NonTrivialTemplate
78+
return 0;
79+
}

0 commit comments

Comments
 (0)