Skip to content

[Interop][SwiftToCxx] Update current enum implementation for new enum design #60564

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/CppInteroperability/UserGuide-CallingSwiftFromC++.md
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ enum value will abort the program.
A resilient Swift enumeration value could represent a case that's unknown to the client.
Swift forces the client to check if the value is `@uknown default` when switching over
the enumeration to account for that. C++ follows a similar principle,
by exposing an `unknown_default` case that can then be matched in a switch.
by exposing an `unknownDefault` case that can then be matched in a switch.

For example, given the following resilient enumeration:

Expand All @@ -620,14 +620,14 @@ void test(const DateFormatStyle &style) {
case DateFormatStyle::full:
...
break;
case DateFormatStyle::unknown_default: // just like Swift's @unknown default
case DateFormatStyle::unknownDefault: // just like Swift's @unknown default
// Some case value added in a future version of enum.
break;
}
}
```

The `unknown_default` case value is not a constructible case and you will get a compiler error if you try to construct it in C++.
The `unknownDefault` case value is not a constructible case and you will get a compiler error if you try to construct it in C++.

## Using Swift Class Types

Expand Down
180 changes: 107 additions & 73 deletions lib/PrintAsClang/DeclAndTypePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,107 +414,71 @@ class DeclAndTypePrinter::Implementation
return p1.second.tag < p2.second.tag;
});

if (elementTagMapping.empty()) {
os << "\n";
return;
}

os << " enum class cases {\n";
for (const auto &pair : elementTagMapping) {
os << " ";
syntaxPrinter.printIdentifier(pair.first->getNameStr());
os << ",\n";
}
os << " };\n"; // enum class cases' closing bracket

// Printing operator cases()
os << " inline operator cases() const {\n";
if (ED->isResilient()) {
os << " auto tag = _getEnumTag();\n";
for (const auto &pair : elementTagMapping) {
os << " if (tag == " << cxx_synthesis::getCxxImplNamespaceName();
os << "::" << pair.second.globalVariableName << ") return cases::";
syntaxPrinter.printIdentifier(pair.first->getNameStr());
os << ";\n";
}
// TODO: change to Swift's fatalError when it's available in C++
os << " abort();\n";
} else { // non-resilient enum
os << " switch (_getEnumTag()) {\n";
for (const auto &pair : elementTagMapping) {
os << " case " << pair.second.tag << ": return cases::";
syntaxPrinter.printIdentifier(pair.first->getNameStr());
os << ";\n";
}
// TODO: change to Swift's fatalError when it's available in C++
os << " default: abort();\n";
os << " }\n"; // switch's closing bracket
}
os << " }\n"; // operator cases()'s closing bracket

os << '\n';
os << " enum class cases {";
llvm::interleave(
elementTagMapping, os,
[&](const auto &pair) {
os << "\n ";
syntaxPrinter.printIdentifier(pair.first->getNameStr());
},
",");
// TODO: allow custom name for this special case
auto resilientUnknownDefaultCaseName = "unknownDefault";
if (ED->isResilient()) {
os << " inline bool inResilientUnknownCase() const {\n";
os << " auto tag = _getEnumTag();\n";
os << " return";
llvm::interleave(
elementTagMapping, os,
[&](const auto &pair) {
os << "\n tag != " << cxx_synthesis::getCxxImplNamespaceName()
<< "::" << pair.second.globalVariableName;
},
" &&");
os << ";\n";
os << " }\n";
os << ",\n " << resilientUnknownDefaultCaseName;
}
os << "\n };\n\n"; // enum class cases' closing bracket

// Printing case-related functions
// Printing struct, is, and get functions for each case
DeclAndTypeClangFunctionPrinter clangFuncPrinter(
os, owningPrinter.prologueOS, owningPrinter.typeMapping,
owningPrinter.interopContext);

for (const auto &pair : elementTagMapping) {
auto printIsFunction = [&](StringRef caseName, EnumDecl *ED) {
os << " inline bool is";
auto name = pair.first->getNameStr().str();
std::string name;
llvm::raw_string_ostream nameStream(name);
ClangSyntaxPrinter(nameStream).printIdentifier(caseName);
name[0] = std::toupper(name[0]);
os << name << "() const {\n";
os << " return _getEnumTag() == ";
if (ED->isResilient()) {
os << cxx_synthesis::getCxxImplNamespaceName()
<< "::" << pair.second.globalVariableName;
} else {
os << pair.second.tag;
}
os << ";\n }\n";

if (!pair.first->hasAssociatedValues()) {
continue;
}
os << " return *this == ";
syntaxPrinter.printBaseName(ED);
os << "::";
syntaxPrinter.printIdentifier(caseName);
os << ";\n";
os << " }\n";
};

auto associatedValueList = pair.first->getParameterList();
auto printGetFunction = [&](EnumElementDecl *elementDecl) {
auto associatedValueList = elementDecl->getParameterList();
// TODO: add tuple type support
if (associatedValueList->size() > 1) {
continue;
return;
}
auto firstType = associatedValueList->front()->getType();
auto firstTypeDecl = firstType->getNominalOrBoundGenericNominal();
OptionalTypeKind optKind;
std::tie(firstType, optKind) =
getObjectTypeAndOptionality(firstTypeDecl, firstType);

// FIXME: (tongjie) may have to forward declare return type
auto name = elementDecl->getNameStr().str();
name[0] = std::toupper(name[0]);

// FIXME: may have to forward declare return type
os << " inline ";
clangFuncPrinter.printClangFunctionReturnType(
firstType, optKind, firstTypeDecl->getModuleContext(),
owningPrinter.outputLang);
os << " get" << name << "() const {\n";
os << " if (!is" << name << "()) abort();\n";
os << " alignas(";
syntaxPrinter.printBaseName(ED);
syntaxPrinter.printBaseName(elementDecl->getParentEnum());
os << ") unsigned char buffer[sizeof(";
syntaxPrinter.printBaseName(ED);
syntaxPrinter.printBaseName(elementDecl->getParentEnum());
os << ")];\n";
os << " auto *thisCopy = new(buffer) ";
syntaxPrinter.printBaseName(ED);
syntaxPrinter.printBaseName(elementDecl->getParentEnum());
os << "(*this);\n";
os << " char * _Nonnull payloadFromDestruction = "
"thisCopy->_destructiveProjectEnumData();\n";
Expand All @@ -531,7 +495,8 @@ class DeclAndTypePrinter::Implementation
} else {
os << " return ";
syntaxPrinter.printModuleNamespaceQualifiersIfNeeded(
firstTypeDecl->getModuleContext(), ED->getModuleContext());
firstTypeDecl->getModuleContext(),
elementDecl->getParentEnum()->getModuleContext());
os << cxx_synthesis::getCxxImplNamespaceName();
os << "::";
ClangValueTypePrinter::printCxxImplClassName(os, firstTypeDecl);
Expand All @@ -542,8 +507,77 @@ class DeclAndTypePrinter::Implementation
os << "::initializeWithTake(result, payloadFromDestruction);\n";
os << " });\n";
}
os << " }\n";
os << " }\n"; // closing bracket of get function
};

auto printStruct = [&](StringRef caseName, EnumElementDecl *elementDecl) {
os << " static struct { // impl struct for case " << caseName << '\n';
os << " inline constexpr operator cases() const {\n";
os << " return cases::";
syntaxPrinter.printIdentifier(caseName);
os << ";\n";
os << " }\n";
if (elementDecl != nullptr) {
os << " inline ";
syntaxPrinter.printBaseName(elementDecl->getParentEnum());
os << " operator()(";
// TODO: implement parameter for associated value
os << ") const {\n";
// TODO: print _make for now; need to print actual code making an enum
os << " return ";
syntaxPrinter.printBaseName(elementDecl->getParentEnum());
os << "::_make();\n";
os << " }\n";
}
os << " } ";
syntaxPrinter.printIdentifier(caseName);
os << ";\n";
};

for (const auto &pair : elementTagMapping) {
// Printing struct
printStruct(pair.first->getNameStr(), pair.first);
// Printing `is` function
printIsFunction(pair.first->getNameStr(), ED);
if (pair.first->hasAssociatedValues()) {
// Printing `get` function
printGetFunction(pair.first);
}
os << '\n';
}

if (ED->isResilient()) {
// Printing struct for unknownDefault
printStruct(resilientUnknownDefaultCaseName, /* elementDecl */ nullptr);
// Printing isUnknownDefault
printIsFunction(resilientUnknownDefaultCaseName, ED);
os << '\n';
}
os << '\n';

// Printing operator cases()
os << " inline operator cases() const {\n";
if (ED->isResilient()) {
os << " auto tag = _getEnumTag();\n";
for (const auto &pair : elementTagMapping) {
os << " if (tag == " << cxx_synthesis::getCxxImplNamespaceName();
os << "::" << pair.second.globalVariableName << ") return cases::";
syntaxPrinter.printIdentifier(pair.first->getNameStr());
os << ";\n";
}
os << " return cases::" << resilientUnknownDefaultCaseName << ";\n";
} else { // non-resilient enum
os << " switch (_getEnumTag()) {\n";
for (const auto &pair : elementTagMapping) {
os << " case " << pair.second.tag << ": return cases::";
syntaxPrinter.printIdentifier(pair.first->getNameStr());
os << ";\n";
}
// TODO: change to Swift's fatalError when it's available in C++
os << " default: abort();\n";
os << " }\n"; // switch's closing bracket
}
os << " }\n"; // operator cases()'s closing bracket
os << "\n";
});
os << outOfLineDefinitions;
Expand Down
31 changes: 30 additions & 1 deletion lib/PrintAsClang/PrintClangValueType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,36 @@ void ClangValueTypePrinter::printValueTypeDecl(
os << " friend class " << cxx_synthesis::getCxxImplNamespaceName() << "::";
printCxxImplClassName(os, typeDecl);
os << ";\n";
os << "};\n\n";
os << "};\n";
// Print the definition of enum static struct data memebers
if (isa<EnumDecl>(typeDecl)) {
auto tagMapping = interopContext.getIrABIDetails().getEnumTagMapping(
cast<EnumDecl>(typeDecl));
for (const auto &pair : tagMapping) {
os << "decltype(";
printer.printBaseName(typeDecl);
os << "::";
printer.printIdentifier(pair.first->getNameStr());
os << ") ";
printer.printBaseName(typeDecl);
os << "::";
printer.printIdentifier(pair.first->getNameStr());
os << ";\n";
}
if (isOpaqueLayout) {
os << "decltype(";
printer.printBaseName(typeDecl);
// TODO: allow custom name for this special case
os << "::";
printer.printIdentifier("unknownDefault");
os << ") ";
printer.printBaseName(typeDecl);
os << "::";
printer.printIdentifier("unknownDefault");
os << ";\n";
}
}
os << '\n';

const auto *moduleContext = typeDecl->getModuleContext();
// Print out the "hidden" _impl class.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,38 @@
#include <iostream>
#include "enums.h"

using namespace Enums;

void useFooInSwitch(const Foo& f) {
switch (f) {
case Foo::a:
std::cout << "Foo::a\n";
break;;
case Foo::unknownDefault:
std::cout << "Foo::unknownDefault\n";
break;
}
}

int main() {
using namespace Enums;
auto f1 = makeFoo(10);
auto f2 = makeFoo(-10);

printFoo(f1);
printFoo(f2);

assert(!f2.inResilientUnknownCase());
if (f1.inResilientUnknownCase()) {
assert(!f2.isUnknownDefault());
if (f1.isUnknownDefault()) {
std::cout << "f1.inResilientUnknownCase()\n";
assert(!f1.isA());
} else {
assert(f1.isA());
assert(f1.getA() == 10.0);
}

useFooInSwitch(f1);
useFooInSwitch(f2);

return 0;
}

Expand All @@ -43,3 +58,7 @@ int main() {
// CHECK-NEXT: a(-10.0)

// NEW_CASE: f1.inResilientUnknownCase()

// NEW_CASE: Foo::unknownDefault
// OLD_CASE: Foo::a
// CHECK-NEXT: Foo::a
28 changes: 14 additions & 14 deletions test/Interop/SwiftToCxx/enums/resilient-enum-in-cxx.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,22 @@ public func printFoo(_ x: Foo) {
// CHECK-EMPTY:
// CHECK-NEXT: class Foo final {
// CHECK-NEXT: public:
// CHECK: enum class cases {
// CHECK-NEXT: a,
// NEW_CASE-NEXT: b,
// CHECK-NEXT: unknownDefault
// CHECK-NEXT: }
// CHECK: static struct { // impl struct for case unknownDefault
// CHECK-NEXT: constexpr operator cases() const {
// CHECK-NEXT: return cases::unknownDefault;
// CHECK-NEXT: }
// CHECK-NEXT: } unknownDefault;
// CHECK-NEXT: inline bool isUnknownDefault() const {
// CHECK-NEXT: return *this == Foo::unknownDefault;
// CHECK-NEXT: }
// CHECK: inline operator cases() const {
// CHECK-NEXT: auto tag = _getEnumTag();
// CHECK-NEXT: if (tag == _impl::$s5Enums3FooO1ayACSdcACmFWC) return cases::a;
// NEW_CASE-NEXT: if (tag == _impl::$s5Enums3FooO1byACSicACmFWC) return cases::b;
// CHECK-NEXT: abort();
// CHECK-NEXT: }
// CHECK-NEXT: inline bool inResilientUnknownCase() const {
// CHECK-NEXT: auto tag = _getEnumTag();
// CHECK-NEXT: return
// OLD_CASE-NEXT: tag != _impl::$s5Enums3FooO1ayACSdcACmFWC;
// NEW_CASE-NEXT: tag != _impl::$s5Enums3FooO1ayACSdcACmFWC &&
// NEW_CASE-NEXT: tag != _impl::$s5Enums3FooO1byACSicACmFWC;
// CHECK-NEXT: }
// CHECK-NEXT: inline bool isA() const {
// CHECK-NEXT: return _getEnumTag() == _impl::$s5Enums3FooO1ayACSdcACmFWC;
// CHECK-NEXT: return cases::unknownDefault;
// CHECK-NEXT: }
// NEW_CASE: inline bool isB() const {
// NEW_CASE-NEXT: return _getEnumTag() == _impl::$s5Enums3FooO1byACSicACmFWC;
// NEW_CASE-NEXT: }
Loading