Skip to content

[cxx-interop] Synthesize conformances to CxxSequence #60332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/AST/KnownProtocols.def
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ PROTOCOL(DistributedTargetInvocationDecoder)
PROTOCOL(DistributedTargetInvocationResultHandler)

// C++ Standard Library Overlay:
PROTOCOL(CxxSequence)
PROTOCOL(UnsafeCxxInputIterator)

PROTOCOL(AsyncSequence)
Expand Down
1 change: 1 addition & 0 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
case KnownProtocolKind::DistributedTargetInvocationResultHandler:
M = getLoadedModule(Id_Distributed);
break;
case KnownProtocolKind::CxxSequence:
case KnownProtocolKind::UnsafeCxxInputIterator:
M = getLoadedModule(Id_Cxx);
break;
Expand Down
77 changes: 77 additions & 0 deletions lib/ClangImporter/ClangDerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,80 @@ void swift::conformToCxxIteratorIfNeeded(
impl.addSynthesizedProtocolAttrs(decl,
{KnownProtocolKind::UnsafeCxxInputIterator});
}

void swift::conformToCxxSequenceIfNeeded(
ClangImporter::Implementation &impl, NominalTypeDecl *decl,
const clang::CXXRecordDecl *clangDecl) {
PrettyStackTraceDecl trace("conforming to CxxSequence", decl);

assert(decl);
assert(clangDecl);
ASTContext &ctx = decl->getASTContext();

ProtocolDecl *cxxIteratorProto =
ctx.getProtocol(KnownProtocolKind::UnsafeCxxInputIterator);
ProtocolDecl *cxxSequenceProto =
ctx.getProtocol(KnownProtocolKind::CxxSequence);
// If the Cxx module is missing, or does not include one of the necessary
// protocols, bail.
if (!cxxIteratorProto || !cxxSequenceProto)
return;

// Check if present: `mutating func __beginUnsafe() -> RawIterator`
auto beginId = ctx.getIdentifier("__beginUnsafe");
auto begins = lookupDirectWithoutExtensions(decl, beginId);
if (begins.size() != 1)
return;
auto begin = dyn_cast<FuncDecl>(begins.front());
if (!begin)
return;
auto rawIteratorTy = begin->getResultInterfaceType();

// Check if present: `mutating func __endUnsafe() -> RawIterator`
auto endId = ctx.getIdentifier("__endUnsafe");
auto ends = lookupDirectWithoutExtensions(decl, endId);
if (ends.size() != 1)
return;
auto end = dyn_cast<FuncDecl>(ends.front());
if (!end)
return;

// Check if `__beginUnsafe` and `__endUnsafe` have the same return type.
auto endTy = end->getResultInterfaceType();
if (!endTy || endTy->getCanonicalType() != rawIteratorTy->getCanonicalType())
return;

// Check if RawIterator conforms to UnsafeCxxInputIterator.
auto rawIteratorConformanceRef = decl->getModuleContext()->lookupConformance(
rawIteratorTy, cxxIteratorProto);
if (!rawIteratorConformanceRef.isConcrete())
return;
auto rawIteratorConformance = rawIteratorConformanceRef.getConcrete();
auto pointeeDecl =
cxxIteratorProto->getAssociatedType(ctx.getIdentifier("Pointee"));
assert(pointeeDecl &&
"UnsafeCxxInputIterator must have a Pointee associated type");
auto pointeeTy = rawIteratorConformance->getTypeWitness(pointeeDecl);
assert(pointeeTy && "valid conformance must have a Pointee witness");

// Take the default definition of `Iterator` from CxxSequence protocol. This
// type is currently `CxxIterator<Self>`.
auto iteratorDecl = cxxSequenceProto->getAssociatedType(ctx.Id_Iterator);
auto iteratorTy = iteratorDecl->getDefaultDefinitionType();
// Substitute generic `Self` parameter.
auto cxxSequenceSelfTy = cxxSequenceProto->getSelfInterfaceType();
auto declSelfTy = decl->getDeclaredInterfaceType();
iteratorTy = iteratorTy.subst(
[&](SubstitutableType *dependentType) {
if (dependentType->isEqual(cxxSequenceSelfTy))
return declSelfTy;
return Type(dependentType);
},
LookUpConformanceInModule(decl->getModuleContext()));

impl.addSynthesizedTypealias(decl, ctx.Id_Element, pointeeTy);
impl.addSynthesizedTypealias(decl, ctx.Id_Iterator, iteratorTy);
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("RawIterator"),
rawIteratorTy);
impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxSequence});
}
8 changes: 7 additions & 1 deletion lib/ClangImporter/ClangDerivedConformances.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,17 @@ namespace swift {
bool isIterator(const clang::CXXRecordDecl *clangDecl);

/// If the decl is a C++ input iterator, synthesize a conformance to the
/// UnsafeCxxInputIterator protocol, which is defined in the std overlay.
/// UnsafeCxxInputIterator protocol, which is defined in the Cxx module.
void conformToCxxIteratorIfNeeded(ClangImporter::Implementation &impl,
NominalTypeDecl *decl,
const clang::CXXRecordDecl *clangDecl);

/// If the decl is a C++ sequence, synthesize a conformance to the CxxSequence
/// protocol, which is defined in the Cxx module.
void conformToCxxSequenceIfNeeded(ClangImporter::Implementation &impl,
NominalTypeDecl *decl,
const clang::CXXRecordDecl *clangDecl);

} // namespace swift

#endif // SWIFT_CLANG_DERIVED_CONFORMANCES_H
1 change: 1 addition & 0 deletions lib/ClangImporter/ImportDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2612,6 +2612,7 @@ namespace {
if (clangModule && requiresCPlusPlus(clangModule)) {
if (auto structDecl = dyn_cast_or_null<NominalTypeDecl>(result)) {
conformToCxxIteratorIfNeeded(Impl, structDecl, decl);
conformToCxxSequenceIfNeeded(Impl, structDecl, decl);
}
}

Expand Down
1 change: 1 addition & 0 deletions lib/IRGen/GenMeta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5808,6 +5808,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::DistributedTargetInvocationEncoder:
case KnownProtocolKind::DistributedTargetInvocationDecoder:
case KnownProtocolKind::DistributedTargetInvocationResultHandler:
case KnownProtocolKind::CxxSequence:
case KnownProtocolKind::UnsafeCxxInputIterator:
case KnownProtocolKind::SerialExecutor:
case KnownProtocolKind::Sendable:
Expand Down
49 changes: 49 additions & 0 deletions test/Interop/Cxx/stdlib/overlay/Inputs/custom-sequence.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,53 @@ struct SimpleEmptySequence {
const int *end() const { return nullptr; }
};

struct HasMutatingBeginEnd {
ConstIterator begin() { return ConstIterator(1); }
ConstIterator end() { return ConstIterator(5); }
};

// TODO: this should conform to CxxSequence.
struct __attribute__((swift_attr("import_reference"),
swift_attr("retain:immortal"),
swift_attr("release:immortal"))) ImmortalSequence {
ConstIterator begin() { return ConstIterator(1); }
ConstIterator end() { return ConstIterator(5); }
};

// MARK: Types that are not actually sequences

struct HasNoBeginMethod {
ConstIterator end() const { return ConstIterator(1); }
};

struct HasNoEndMethod {
ConstIterator begin() const { return ConstIterator(1); }
};

struct HasBeginEndTypeMismatch {
ConstIterator begin() const { return ConstIterator(1); }
ConstIteratorOutOfLineEq end() const { return ConstIteratorOutOfLineEq(3); }
};

struct HasBeginEndReturnNonIterators {
struct NotIterator {};

NotIterator begin() const { return NotIterator(); }
NotIterator end() const { return NotIterator(); }
};

// TODO: this should not be conformed to CxxSequence, because
// `const ConstIterator &` is imported as `UnsafePointer<ConstIterator>`, and
// calling `successor()` is not actually going to call
// `ConstIterator::operator++()`. It will increment the address instead.
struct HasBeginEndReturnRef {
private:
ConstIterator b = ConstIterator(1);
ConstIterator e = ConstIterator(5);

public:
const ConstIterator &begin() const { return b; }
const ConstIterator &end() const { return e; }
};

#endif // TEST_INTEROP_CXX_STDLIB_INPUTS_CUSTOM_SEQUENCE_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// RUN: %target-swift-ide-test -print-module -module-to-print=CustomSequence -source-filename=x -I %S/Inputs -enable-experimental-cxx-interop -module-cache-path %t | %FileCheck %s

// CHECK: import Cxx

// CHECK: struct SimpleSequence : CxxSequence {
// CHECK: typealias Element = ConstIterator.Pointee
// CHECK: typealias Iterator = CxxIterator<SimpleSequence>
// CHECK: typealias RawIterator = ConstIterator
// CHECK: }

// CHECK: struct SimpleSequenceWithOutOfLineEqualEqual : CxxSequence {
// CHECK: typealias Element = ConstIteratorOutOfLineEq.Pointee
// CHECK: typealias Iterator = CxxIterator<SimpleSequenceWithOutOfLineEqualEqual>
// CHECK: typealias RawIterator = ConstIteratorOutOfLineEq
// CHECK: }

// CHECK: struct SimpleArrayWrapper : CxxSequence {
// CHECK: typealias Element = UnsafePointer<Int32>.Pointee
// CHECK: typealias Iterator = CxxIterator<SimpleArrayWrapper>
// CHECK: typealias RawIterator = UnsafePointer<Int32>
// CHECK: }

// CHECK: struct SimpleArrayWrapperNullableIterators : CxxSequence {
// CHECK: typealias Element = Optional<UnsafePointer<Int32>>.Pointee
// CHECK: typealias Iterator = CxxIterator<SimpleArrayWrapperNullableIterators>
// CHECK: typealias RawIterator = UnsafePointer<Int32>?
// CHECK: }

// CHECK: struct SimpleEmptySequence : CxxSequence {
// CHECK: typealias Element = Optional<UnsafePointer<Int32>>.Pointee
// CHECK: typealias Iterator = CxxIterator<SimpleEmptySequence>
// CHECK: typealias RawIterator = UnsafePointer<Int32>?
// CHECK: }

// CHECK: struct HasMutatingBeginEnd : CxxSequence {
// CHECK: typealias Element = ConstIterator.Pointee
// CHECK: typealias Iterator = CxxIterator<HasMutatingBeginEnd>
// CHECK: typealias RawIterator = ConstIterator
// CHECK: }

// CHECK: struct HasNoBeginMethod {
// CHECK-NOT: typealias Element
// CHECK-NOT: typealias Iterator
// CHECK-NOT: typealias RawIterator
// CHECK: }
// CHECK: struct HasNoEndMethod {
// CHECK-NOT: typealias Element
// CHECK-NOT: typealias Iterator
// CHECK-NOT: typealias RawIterator
// CHECK: }
// CHECK: struct HasBeginEndTypeMismatch {
// CHECK-NOT: typealias Element
// CHECK-NOT: typealias Iterator
// CHECK-NOT: typealias RawIterator
// CHECK: }
// CHECK: struct HasBeginEndReturnNonIterators {
// CHECK-NOT: typealias Element
// CHECK-NOT: typealias Iterator
// CHECK-NOT: typealias RawIterator
// CHECK: }
24 changes: 14 additions & 10 deletions test/Interop/Cxx/stdlib/overlay/custom-sequence-typechecker.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,7 @@
import CustomSequence
import Cxx

// === SimpleSequence ===
// Conformance to UnsafeCxxInputIterator is synthesized.
extension SimpleSequence: CxxSequence {}

func checkSimpleSequence() {
let seq = SimpleSequence()
func checkIntSequence<S>(_ seq: S) where S: Sequence, S.Element == Int32 {
let contains = seq.contains(where: { $0 == 3 })
print(contains)

Expand All @@ -17,17 +12,26 @@ func checkSimpleSequence() {
}
}

// === SimpleSequence ===
// Conformance to UnsafeCxxInputIterator is synthesized.
// Conformance to CxxSequence is synthesized.
checkIntSequence(SimpleSequence())

// === SimpleSequenceWithOutOfLineEqualEqual ===
extension SimpleSequenceWithOutOfLineEqualEqual : CxxSequence {}
// Conformance to CxxSequence is synthesized.
checkIntSequence(SimpleSequenceWithOutOfLineEqualEqual())

// === SimpleArrayWrapper ===
// No UnsafeCxxInputIterator conformance required, since the iterators are actually UnsafePointers here.
extension SimpleArrayWrapper: CxxSequence {}
// Conformance to CxxSequence is synthesized.
checkIntSequence(SimpleArrayWrapper())

// === SimpleArrayWrapperNullableIterators ===
// No UnsafeCxxInputIterator conformance required, since the iterators are actually optional UnsafePointers here.
extension SimpleArrayWrapperNullableIterators: CxxSequence {}
// Conformance to CxxSequence is synthesized.
checkIntSequence(SimpleArrayWrapperNullableIterators())

// === SimpleEmptySequence ===
// No UnsafeCxxInputIterator conformance required, since the iterators are actually optional UnsafePointers here.
extension SimpleEmptySequence: CxxSequence {}
// Conformance to CxxSequence is synthesized.
checkIntSequence(SimpleEmptySequence())
5 changes: 0 additions & 5 deletions test/Interop/Cxx/stdlib/overlay/custom-sequence.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@ import Cxx

var CxxSequenceTestSuite = TestSuite("CxxSequence")

extension SimpleSequence: CxxSequence {}

extension SimpleEmptySequence: CxxSequence {}


CxxSequenceTestSuite.test("SimpleSequence as Swift.Sequence") {
let seq = SimpleSequence()
let contains = seq.contains(where: { $0 == 3 })
Expand Down