Skip to content

Commit 1a449f5

Browse files
committed
[interop][SwiftToCxx] emit members in extensions for same-type
1 parent 0fea0c4 commit 1a449f5

File tree

5 files changed

+89
-4
lines changed

5 files changed

+89
-4
lines changed

lib/PrintAsClang/DeclAndTypePrinter.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,17 @@ class DeclAndTypePrinter::Implementation
348348
// FIXME: Print struct's availability.
349349
ClangValueTypePrinter printer(os, owningPrinter.prologueOS,
350350
owningPrinter.interopContext);
351-
printer.printValueTypeDecl(
352-
SD, /*bodyPrinter=*/[&]() { printMembers(SD->getMembers()); });
351+
printer.printValueTypeDecl(SD, /*bodyPrinter=*/[&]() {
352+
printMembers(SD->getMembers());
353+
for (const auto *ed :
354+
owningPrinter.interopContext.getExtensionsForNominalType(SD)) {
355+
auto sign = ed->getGenericSignature();
356+
// FIXME: support requirements.
357+
if (!sign.getRequirements().empty())
358+
continue;
359+
printMembers(ed->getMembers());
360+
}
361+
});
353362
}
354363

355364
void visitExtensionDecl(ExtensionDecl *ED) {
@@ -881,7 +890,11 @@ class DeclAndTypePrinter::Implementation
881890
if (isClassMethod)
882891
return;
883892
assert(!AFD->isStatic());
884-
auto *typeDeclContext = cast<NominalTypeDecl>(AFD->getParent());
893+
auto *typeDeclContext = dyn_cast<NominalTypeDecl>(AFD->getParent());
894+
if (!typeDeclContext) {
895+
typeDeclContext =
896+
cast<ExtensionDecl>(AFD->getParent())->getExtendedNominal();
897+
}
885898

886899
std::string cFuncDecl;
887900
llvm::raw_string_ostream cFuncPrologueOS(cFuncDecl);

lib/PrintAsClang/ModuleContentsWriter.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,13 @@ class ModuleWriter {
452452
bool writeStruct(const StructDecl *SD) {
453453
if (addImport(SD))
454454
return true;
455-
if (outputLangMode == OutputLanguageMode::Cxx)
455+
if (outputLangMode == OutputLanguageMode::Cxx) {
456456
(void)forwardDeclareMemberTypes(SD->getMembers(), SD);
457+
for (const auto *ed :
458+
printer.getInteropContext().getExtensionsForNominalType(SD)) {
459+
(void)forwardDeclareMemberTypes(ed->getMembers(), SD);
460+
}
461+
}
457462
printer.print(SD);
458463
return true;
459464
}
@@ -555,6 +560,8 @@ class ModuleWriter {
555560
return !printer.shouldInclude(VD);
556561

557562
if (auto ED = dyn_cast<ExtensionDecl>(D)) {
563+
if (outputLangMode == OutputLanguageMode::Cxx)
564+
return false;
558565
auto baseClass = ED->getSelfClassDecl();
559566
return !baseClass || !printer.shouldInclude(baseClass) ||
560567
baseClass->isForeign();
@@ -580,6 +587,8 @@ class ModuleWriter {
580587

581588
if (auto ED = dyn_cast<ExtensionDecl>(D)) {
582589
auto baseClass = ED->getSelfClassDecl();
590+
if (!baseClass)
591+
return ED->getExtendedNominal()->getName().str();
583592
return baseClass->getName().str();
584593
}
585594
llvm_unreachable("unknown top-level ObjC decl");
@@ -632,6 +641,16 @@ class ModuleWriter {
632641
assert(declsToWrite.empty());
633642
declsToWrite.assign(decls.begin(), decls.end());
634643

644+
if (outputLangMode == OutputLanguageMode::Cxx) {
645+
for (const Decl *D : declsToWrite) {
646+
if (auto *ED = dyn_cast<ExtensionDecl>(D)) {
647+
const auto *type = ED->getExtendedNominal();
648+
if (isa<StructDecl>(type))
649+
printer.getInteropContext().recordExtensions(type, ED);
650+
}
651+
}
652+
}
653+
635654
while (!declsToWrite.empty()) {
636655
const Decl *D = declsToWrite.back();
637656
bool success = true;

lib/PrintAsClang/SwiftToClangInteropContext.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,19 @@ void SwiftToClangInteropContext::recordEmittedClangTypeDecl(
4040
assert(typeDecl->hasClangNode());
4141
referencedClangTypeDecls.insert(typeDecl);
4242
}
43+
44+
void SwiftToClangInteropContext::recordExtensions(
45+
const NominalTypeDecl *typeDecl, const ExtensionDecl *ext) {
46+
auto it = extensions.insert(
47+
std::make_pair(typeDecl, std::vector<const ExtensionDecl *>()));
48+
it.first->second.push_back(ext);
49+
}
50+
51+
llvm::ArrayRef<const ExtensionDecl *>
52+
SwiftToClangInteropContext::getExtensionsForNominalType(
53+
const NominalTypeDecl *typeDecl) const {
54+
auto exts = extensions.find(typeDecl);
55+
if (exts != extensions.end())
56+
return exts->getSecond();
57+
return {};
58+
}

lib/PrintAsClang/SwiftToClangInteropContext.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#define SWIFT_PRINTASCLANG_SWIFTTOCLANGINTEROPCONTEXT_H
1515

1616
#include "llvm/ADT/ArrayRef.h"
17+
#include "llvm/ADT/DenseMap.h"
18+
#include "llvm/ADT/Optional.h"
1719
#include "llvm/ADT/STLExtras.h"
1820
#include "llvm/ADT/SetVector.h"
1921
#include "llvm/ADT/StringSet.h"
@@ -25,6 +27,7 @@ class Decl;
2527
class IRABIDetailsProvider;
2628
class IRGenOptions;
2729
class ModuleDecl;
30+
class ExtensionDecl;
2831
class NominalTypeDecl;
2932

3033
/// The \c SwiftToClangInteropContext class is responsible for providing
@@ -52,12 +55,20 @@ class SwiftToClangInteropContext {
5255
return referencedClangTypeDecls;
5356
}
5457

58+
void recordExtensions(const NominalTypeDecl *typeDecl,
59+
const ExtensionDecl *ext);
60+
61+
llvm::ArrayRef<const ExtensionDecl *>
62+
getExtensionsForNominalType(const NominalTypeDecl *typeDecl) const;
63+
5564
private:
5665
ModuleDecl &mod;
5766
const IRGenOptions &irGenOpts;
5867
std::unique_ptr<IRABIDetailsProvider> irABIDetails;
5968
llvm::StringSet<> emittedStubs;
6069
llvm::SetVector<const NominalTypeDecl *> referencedClangTypeDecls;
70+
llvm::DenseMap<const NominalTypeDecl *, std::vector<const ExtensionDecl *>>
71+
extensions;
6172
};
6273

6374
} // end namespace swift
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-swift-frontend %s -typecheck -module-name Structs -clang-header-expose-public-decls -emit-clang-header-path %t/structs.h
3+
// RUN: %FileCheck %s < %t/structs.h
4+
5+
// RUN: %check-interop-cxx-header-in-clang(%t/structs.h -Wno-unused-private-field -Wno-unused-function)
6+
7+
8+
public struct TypeAfterArray {
9+
var x: Int16
10+
}
11+
12+
public struct Array {
13+
public var x: Int
14+
}
15+
16+
extension Array {
17+
public var val: Structs.TypeAfterArray {
18+
return TypeAfterArray(x: 42)
19+
}
20+
}
21+
22+
// CHECK class TypeAfterArray;
23+
// CHECK: class Array final {
24+
// CHECK: swift::Int getX() const;
25+
// CHECK-NEXT: inline void setX(swift::Int value);
26+
// CHECK-NEXT: TypeAfterArray getVal() const;

0 commit comments

Comments
 (0)