Skip to content

Commit a0d8ba2

Browse files
authored
Merge pull request #61913 from hyp/eng/generic-enum-init
[interop][SwiftToCxx] add support for constructing generic enum cases…
2 parents 7853184 + 8336edd commit a0d8ba2

File tree

11 files changed

+240
-102
lines changed

11 files changed

+240
-102
lines changed

docs/CppInteroperability/CppInteroperabilityStatus.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,4 +218,4 @@ This status table describes which of the following Swift standard library APIs h
218218
|--------------------------------|----------------------------------------------------------|
219219
| `String` | Can be used as a type in C++. APIs in extensions are not exposed to C++. Conversion between `std.string` is not yet supported |
220220
| `Array<T>` | Can be used as a type in C++. Ranged for loops are supported. Limited set of APIs in some extensions are exposed to C++. |
221-
| `Optional<T>` | Can be used as a type in C++. `get` extracts the optional value and it's also implicitly castable to `bool`. Can't be constructed from C++ yet. |
221+
| `Optional<T>` | Can be used as a type in C++. Can be constructed. `get` extracts the optional value and it's also implicitly castable to `bool`. |

lib/PrintAsClang/DeclAndTypePrinter.cpp

Lines changed: 113 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,7 @@ class DeclAndTypePrinter::Implementation
466466

467467
clangFuncPrinter.printCustomCxxFunction(
468468
{paramType},
469+
/*NeedsReturnTypes=*/true,
469470
[&](auto &types) {
470471
// Printing function name and return type
471472
os << " inline " << types[paramType] << " get" << name;
@@ -535,13 +536,13 @@ class DeclAndTypePrinter::Implementation
535536
outOfLineOS << " });\n ";
536537
}
537538
},
538-
ED->getModuleContext(), outOfLineOS);
539+
ED, ED->getModuleContext(), outOfLineOS);
539540
};
540541

541542
auto printStruct = [&](StringRef caseName, EnumElementDecl *elementDecl,
542543
Optional<IRABIDetailsProvider::EnumElementInfo>
543544
elementInfo) {
544-
os << " inline const static struct { "
545+
os << " inline const static struct _impl_" << caseName << " { "
545546
<< "// impl struct for case " << caseName << '\n';
546547
os << " inline constexpr operator cases() const {\n";
547548
os << " return cases::";
@@ -552,67 +553,112 @@ class DeclAndTypePrinter::Implementation
552553
if (elementDecl != nullptr) {
553554
assert(elementInfo.hasValue());
554555

555-
Type paramType, objectType;
556-
NominalTypeDecl *objectTypeDecl = nullptr;
557-
OptionalTypeKind optKind;
556+
Type paramType;
558557

559558
// TODO: support tuple type
560559
if (elementDecl->hasAssociatedValues() &&
561560
elementDecl->getParameterList()->size() == 1) {
562-
paramType = elementDecl->getParameterList()->front()->getType();
563-
std::tie(objectType, optKind) = getObjectTypeAndOptionality(
564-
paramType->getNominalOrBoundGenericNominal(), paramType);
565-
objectTypeDecl = objectType->getNominalOrBoundGenericNominal();
561+
paramType =
562+
elementDecl->getParameterList()->front()->getInterfaceType();
566563
}
567564

568565
SmallVector<Type> neededTypes;
569566
if (paramType) {
570567
neededTypes.push_back(paramType);
571568
}
572569

573-
// FIXME: support generic constructor.
574-
if (!ED->isGeneric())
575-
clangFuncPrinter.printCustomCxxFunction(
576-
neededTypes,
577-
[&](auto &types) {
578-
// Printing function name and return type
579-
os << " inline ";
580-
syntaxPrinter.printBaseName(elementDecl->getParentEnum());
581-
os << " operator()";
582-
583-
outOfLineOS << " inline ";
584-
outOfLineSyntaxPrinter.printBaseName(
585-
elementDecl->getParentEnum());
586-
outOfLineOS << ' ';
587-
outOfLineSyntaxPrinter.printBaseName(
588-
elementDecl->getParentEnum());
589-
outOfLineOS << "::_impl_" << elementDecl->getNameStr()
590-
<< "::operator()";
591-
},
592-
[&](auto &types) {
593-
// Printing parameters
594-
if (!paramType) {
595-
return;
596-
}
597-
assert(objectTypeDecl != nullptr);
598-
if (owningPrinter.typeMapping.getKnownCxxTypeInfo(
599-
objectTypeDecl)) {
600-
os << types[paramType] << " val";
601-
outOfLineOS << types[paramType] << " val";
570+
clangFuncPrinter.printCustomCxxFunction(
571+
neededTypes,
572+
/*NeedsReturnTypes=*/false,
573+
[&](auto &types) {
574+
const auto *ED = elementDecl->getParentEnum();
575+
// Printing function name and return type
576+
os << " inline ";
577+
syntaxPrinter.printNominalTypeReference(ED,
578+
ED->getModuleContext());
579+
os << " operator()";
580+
581+
outOfLineSyntaxPrinter
582+
.printNominalTypeOutsideMemberDeclTemplateSpecifiers(ED);
583+
outOfLineOS << " inline ";
584+
outOfLineSyntaxPrinter.printNominalTypeReference(
585+
ED, ED->getModuleContext());
586+
outOfLineOS << ' ';
587+
outOfLineSyntaxPrinter.printNominalTypeQualifier(
588+
ED, /*moduleContext=*/ED->getModuleContext());
589+
outOfLineOS << "_impl_" << caseName << "::operator()";
590+
},
591+
[&](auto &types) {
592+
// Printing parameters
593+
if (!paramType) {
594+
return;
595+
}
596+
os << types[paramType] << " val";
597+
outOfLineOS << types[paramType] << " val";
598+
},
599+
true,
600+
[&](auto &types) {
601+
auto *ED = elementDecl->getParentEnum();
602+
// Printing function body
603+
outOfLineOS << " auto result = ";
604+
outOfLineSyntaxPrinter.printNominalTypeQualifier(
605+
ED, ED->getModuleContext());
606+
outOfLineOS << "_make();\n";
607+
if (paramType) {
608+
if (paramType->getAs<GenericTypeParamType>()) {
609+
auto type = types[paramType];
610+
ClangSyntaxPrinter(outOfLineOS)
611+
.printIgnoredCxx17ExtensionDiagnosticBlock([&]() {
612+
// FIXME: handle C++ types.
613+
outOfLineOS << "if constexpr (std::is_base_of<::swift::"
614+
<< cxx_synthesis::getCxxImplNamespaceName()
615+
<< "::RefCountedClass, " << type
616+
<< ">::value) {\n";
617+
outOfLineOS << " void *ptr = ::swift::"
618+
<< cxx_synthesis::getCxxImplNamespaceName()
619+
<< "::_impl_RefCountedClass::"
620+
"copyOpaquePointer(val);\n";
621+
outOfLineOS
622+
<< " memcpy(result._getOpaquePointer(), &ptr, "
623+
"sizeof(ptr));\n";
624+
outOfLineOS << "} else if constexpr (::swift::"
625+
<< cxx_synthesis::getCxxImplNamespaceName()
626+
<< "::isValueType<" << type << ">) {\n";
627+
628+
outOfLineOS << " alignas(" << type;
629+
outOfLineOS << ") unsigned char buffer[sizeof(" << type;
630+
outOfLineOS << ")];\n";
631+
outOfLineOS << " auto *valCopy = new(buffer) "
632+
<< type;
633+
outOfLineOS << "(val);\n";
634+
outOfLineOS << " ";
635+
outOfLineOS << cxx_synthesis::getCxxSwiftNamespaceName()
636+
<< "::";
637+
outOfLineOS << cxx_synthesis::getCxxImplNamespaceName();
638+
outOfLineOS << "::implClassFor<" << type;
639+
outOfLineOS << ">::type::initializeWithTake(result._"
640+
"getOpaquePointer(), ";
641+
outOfLineOS << cxx_synthesis::getCxxSwiftNamespaceName()
642+
<< "::";
643+
outOfLineOS << cxx_synthesis::getCxxImplNamespaceName();
644+
outOfLineOS << "::implClassFor<" << type;
645+
outOfLineOS << ">::type::getOpaquePointer(*valCopy)";
646+
outOfLineOS << ");\n";
647+
outOfLineOS << "} else {\n";
648+
outOfLineOS
649+
<< " memcpy(result._getOpaquePointer(), &val, "
650+
"sizeof(val));\n";
651+
outOfLineOS << "}\n";
652+
});
602653
} else {
603-
os << "const " << types[paramType] << " &val";
604-
outOfLineOS << "const " << types[paramType] << " &val";
605-
}
606-
},
607-
true,
608-
[&](auto &types) {
609-
// Printing function body
610-
outOfLineOS << " auto result = ";
611-
outOfLineSyntaxPrinter.printBaseName(
612-
elementDecl->getParentEnum());
613-
outOfLineOS << "::_make();\n";
614-
615-
if (paramType) {
654+
655+
OptionalTypeKind optKind;
656+
Type objectType;
657+
std::tie(objectType, optKind) =
658+
DeclAndTypePrinter::getObjectTypeAndOptionality(
659+
ED, paramType);
660+
auto objectTypeDecl =
661+
objectType->getNominalOrBoundGenericNominal();
616662
assert(objectTypeDecl != nullptr);
617663

618664
if (owningPrinter.typeMapping.getKnownCxxTypeInfo(
@@ -621,6 +667,8 @@ class DeclAndTypePrinter::Implementation
621667
<< " memcpy(result._getOpaquePointer(), &val, "
622668
"sizeof(val));\n";
623669
} else {
670+
objectTypeDecl =
671+
paramType->getNominalOrBoundGenericNominal();
624672
outOfLineOS << " alignas(";
625673
outOfLineSyntaxPrinter
626674
.printModuleNamespaceQualifiersIfNeeded(
@@ -666,19 +714,20 @@ class DeclAndTypePrinter::Implementation
666714
outOfLineOS << ");\n";
667715
}
668716
}
669-
670-
outOfLineOS << " result._destructiveInjectEnumTag(";
671-
if (ED->isResilient()) {
672-
outOfLineOS << cxx_synthesis::getCxxImplNamespaceName()
673-
<< "::" << elementInfo->globalVariableName;
674-
} else {
675-
outOfLineOS << elementInfo->tag;
676-
}
677-
outOfLineOS << ");\n";
678-
outOfLineOS << " return result;\n";
679-
outOfLineOS << " ";
680-
},
681-
ED->getModuleContext(), outOfLineOS);
717+
}
718+
719+
outOfLineOS << " result._destructiveInjectEnumTag(";
720+
if (ED->isResilient()) {
721+
outOfLineOS << cxx_synthesis::getCxxImplNamespaceName()
722+
<< "::" << elementInfo->globalVariableName;
723+
} else {
724+
outOfLineOS << elementInfo->tag;
725+
}
726+
outOfLineOS << ");\n";
727+
outOfLineOS << " return result;\n";
728+
outOfLineOS << " ";
729+
},
730+
ED, ED->getModuleContext(), outOfLineOS);
682731
}
683732
os << " } ";
684733
syntaxPrinter.printIdentifier(caseName);

lib/PrintAsClang/PrintClangFunction.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,29 +1394,41 @@ bool DeclAndTypeClangFunctionPrinter::hasKnownOptionalNullableCxxMapping(
13941394
}
13951395

13961396
void DeclAndTypeClangFunctionPrinter::printCustomCxxFunction(
1397-
const SmallVector<Type> &neededTypes, PrinterTy retTypeAndNamePrinter,
1398-
PrinterTy paramPrinter, bool isConstFunc, PrinterTy bodyPrinter,
1399-
ModuleDecl *emittedModule, raw_ostream &outOfLineOS) {
1397+
const SmallVector<Type> &neededTypes, bool NeedsReturnTypes,
1398+
PrinterTy retTypeAndNamePrinter, PrinterTy paramPrinter, bool isConstFunc,
1399+
PrinterTy bodyPrinter, ValueDecl *valueDecl, ModuleDecl *emittedModule,
1400+
raw_ostream &outOfLineOS) {
14001401
llvm::MapVector<Type, std::string> types;
1402+
llvm::MapVector<Type, std::string> typeRefs;
14011403

14021404
for (auto &type : neededTypes) {
14031405
std::string typeStr;
14041406
llvm::raw_string_ostream typeOS(typeStr);
14051407
OptionalTypeKind optKind;
14061408
Type objectType;
14071409
std::tie(objectType, optKind) =
1408-
DeclAndTypePrinter::getObjectTypeAndOptionality(
1409-
type->getNominalOrBoundGenericNominal(), type);
1410+
DeclAndTypePrinter::getObjectTypeAndOptionality(valueDecl, type);
14101411

1411-
// Use FunctionSignatureTypeUse::ReturnType to avoid printing extra const or
1412-
// references
14131412
CFunctionSignatureTypePrinter typePrinter(
14141413
typeOS, cPrologueOS, typeMapping, OutputLanguageMode::Cxx,
14151414
interopContext, CFunctionSignatureTypePrinterModifierDelegate(),
1416-
emittedModule, declPrinter, FunctionSignatureTypeUse::ReturnType);
1417-
typePrinter.visit(objectType, optKind, /* isInOutParam */ false);
1418-
1419-
types.insert({type, typeStr});
1415+
emittedModule, declPrinter,
1416+
NeedsReturnTypes ? FunctionSignatureTypeUse::ReturnType
1417+
: FunctionSignatureTypeUse::ParamType);
1418+
auto support =
1419+
typePrinter.visit(objectType, optKind, /* isInOutParam */ false);
1420+
(void)support;
1421+
assert(!support.isUnsupported());
1422+
types.insert({type, typeOS.str()});
1423+
1424+
std::string typeRefStr;
1425+
llvm::raw_string_ostream typeRefOS(typeRefStr);
1426+
CFunctionSignatureTypePrinter typeRefPrinter(
1427+
typeRefOS, cPrologueOS, typeMapping, OutputLanguageMode::Cxx,
1428+
interopContext, CFunctionSignatureTypePrinterModifierDelegate(),
1429+
emittedModule, declPrinter, FunctionSignatureTypeUse::TypeReference);
1430+
typeRefPrinter.visit(objectType, optKind, /* isInOutParam */ false);
1431+
typeRefs.insert({type, typeRefOS.str()});
14201432
}
14211433

14221434
retTypeAndNamePrinter(types);
@@ -1430,6 +1442,6 @@ void DeclAndTypeClangFunctionPrinter::printCustomCxxFunction(
14301442
outOfLineOS << " const";
14311443
}
14321444
outOfLineOS << " {\n";
1433-
bodyPrinter(types);
1445+
bodyPrinter(typeRefs);
14341446
outOfLineOS << "}\n";
14351447
}

lib/PrintAsClang/PrintClangFunction.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,11 @@ class DeclAndTypeClangFunctionPrinter {
147147

148148
/// Print generated C++ helper function
149149
void printCustomCxxFunction(const SmallVector<Type> &neededTypes,
150+
bool NeedsReturnTypes,
150151
PrinterTy retTypeAndNamePrinter,
151152
PrinterTy paramPrinter, bool isConstFunc,
152-
PrinterTy bodyPrinter, ModuleDecl *emittedModule,
153+
PrinterTy bodyPrinter, ValueDecl *valueDecl,
154+
ModuleDecl *emittedModule,
153155
raw_ostream &outOfLineOS);
154156

155157
private:

lib/PrintAsClang/PrintClangValueType.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -342,13 +342,6 @@ void ClangValueTypePrinter::printValueTypeDecl(
342342
os << " return enumVWTable->getEnumTag(_getOpaquePointer(), "
343343
"metadata._0);\n";
344344
os << " }\n";
345-
346-
for (const auto &pair : interopContext.getIrABIDetails().getEnumTagMapping(
347-
cast<EnumDecl>(typeDecl))) {
348-
os << " using _impl_" << pair.first->getNameStr() << " = decltype(";
349-
ClangSyntaxPrinter(os).printIdentifier(pair.first->getNameStr());
350-
os << ");\n";
351-
}
352345
}
353346
// Print out the storage for the value type.
354347
os << " ";

lib/PrintAsClang/_SwiftCxxInteroperability.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,11 @@ class _impl_RefCountedClass {
131131
static inline void *_Nonnull &getOpaquePointerRef(RefCountedClass &object) {
132132
return object._opaquePointer;
133133
}
134+
static inline void *_Nonnull copyOpaquePointer(
135+
const RefCountedClass &object) {
136+
swift_retain(object._opaquePointer);
137+
return object._opaquePointer;
138+
}
134139
};
135140

136141
} // namespace _impl

test/Interop/SwiftToCxx/enums/resilient-enum-in-cxx.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public enum Empty {
4646
// CHECK: enum class cases {
4747
// CHECK-NEXT: unknownDefault
4848
// CHECK-NEXT: };
49-
// CHECK: inline const static struct { // impl struct for case unknownDefault
49+
// CHECK: inline const static struct _impl_unknownDefault { // impl struct for case unknownDefault
5050
// CHECK-NEXT: inline constexpr operator cases() const {
5151
// CHECK-NEXT: return cases::unknownDefault;
5252
// CHECK-NEXT: }
@@ -70,7 +70,7 @@ public enum Empty {
7070
// NEW_CASE-NEXT: b,
7171
// CHECK-NEXT: unknownDefault
7272
// CHECK-NEXT: }
73-
// CHECK: inline const static struct { // impl struct for case unknownDefault
73+
// CHECK: inline const static struct _impl_unknownDefault { // impl struct for case unknownDefault
7474
// CHECK-NEXT: inline constexpr operator cases() const {
7575
// CHECK-NEXT: return cases::unknownDefault;
7676
// CHECK-NEXT: }

0 commit comments

Comments
 (0)