Skip to content

Commit eac5418

Browse files
committed
[cxx-interop] Synthesize conformances to CxxSequence
This makes ClangImporter automatically conform C++ sequence types to `Cxx.CxxSequence` protocol. We consider a C++ type to be a sequence type if it defines `begin()` & `end()` methods that return iterators of the same type which conforms to `UnsafeCxxInputIterator`.
1 parent 379fc1f commit eac5418

File tree

10 files changed

+213
-17
lines changed

10 files changed

+213
-17
lines changed

include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ PROTOCOL(DistributedTargetInvocationDecoder)
105105
PROTOCOL(DistributedTargetInvocationResultHandler)
106106

107107
// C++ Standard Library Overlay:
108+
PROTOCOL(CxxSequence)
108109
PROTOCOL(UnsafeCxxInputIterator)
109110

110111
PROTOCOL(AsyncSequence)

lib/AST/ASTContext.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
10511051
case KnownProtocolKind::DistributedTargetInvocationResultHandler:
10521052
M = getLoadedModule(Id_Distributed);
10531053
break;
1054+
case KnownProtocolKind::CxxSequence:
10541055
case KnownProtocolKind::UnsafeCxxInputIterator:
10551056
M = getLoadedModule(Id_Cxx);
10561057
break;

lib/ClangImporter/ClangDerivedConformances.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "swift/AST/NameLookup.h"
1515
#include "swift/AST/ParameterList.h"
1616
#include "swift/AST/PrettyStackTrace.h"
17+
#include "swift/AST/ProtocolConformance.h"
1718

1819
using namespace swift;
1920

@@ -160,3 +161,83 @@ void swift::conformToCxxIteratorIfNeeded(
160161
impl.addSynthesizedProtocolAttrs(decl,
161162
{KnownProtocolKind::UnsafeCxxInputIterator});
162163
}
164+
165+
void swift::conformToCxxSequenceIfNeeded(
166+
ClangImporter::Implementation &impl, NominalTypeDecl *decl,
167+
const clang::CXXRecordDecl *clangDecl) {
168+
PrettyStackTraceDecl trace("conforming to CxxSequence", decl);
169+
170+
assert(decl);
171+
assert(clangDecl);
172+
ASTContext &ctx = decl->getASTContext();
173+
174+
ProtocolDecl *cxxIteratorProto =
175+
ctx.getProtocol(KnownProtocolKind::UnsafeCxxInputIterator);
176+
ProtocolDecl *cxxSequenceProto =
177+
ctx.getProtocol(KnownProtocolKind::CxxSequence);
178+
// If the Cxx module is missing, or does not include one of the necessary
179+
// protocols, bail.
180+
if (!cxxIteratorProto || !cxxSequenceProto)
181+
return;
182+
183+
// Check if present: `mutating func __beginUnsafe() -> RawIterator`
184+
auto beginId = ctx.getIdentifier("__beginUnsafe");
185+
auto begins = decl->lookupDirect(beginId);
186+
if (begins.size() != 1)
187+
return;
188+
auto begin = dyn_cast<FuncDecl>(begins.front());
189+
if (!begin)
190+
return;
191+
auto rawIteratorTy = begin->getResultInterfaceType();
192+
193+
// Check if present: `mutating func __endUnsafe() -> RawIterator`
194+
auto endId = ctx.getIdentifier("__endUnsafe");
195+
auto ends = decl->lookupDirect(endId);
196+
if (ends.size() != 1)
197+
return;
198+
auto end = dyn_cast<FuncDecl>(ends.front());
199+
if (!end)
200+
return;
201+
202+
// Check if `__beginUnsafe` and `__endUnsafe` have the same return type.
203+
auto endTy = end->getResultInterfaceType();
204+
if (!endTy || endTy->getCanonicalType() != rawIteratorTy->getCanonicalType())
205+
return;
206+
207+
// Check if RawIterator conforms to UnsafeCxxInputIterator.
208+
auto rawIteratorConformanceRef = decl->getModuleContext()->lookupConformance(
209+
rawIteratorTy, cxxIteratorProto);
210+
if (!rawIteratorConformanceRef.isConcrete())
211+
return;
212+
auto rawIteratorConformance = rawIteratorConformanceRef.getConcrete();
213+
auto pointeeDecl =
214+
cxxIteratorProto->getAssociatedType(ctx.getIdentifier("Pointee"));
215+
assert(pointeeDecl &&
216+
"UnsafeCxxInputIterator must have a Pointee associated type");
217+
auto pointeeTy = rawIteratorConformance->getTypeWitness(pointeeDecl);
218+
assert(pointeeTy && "valid conformance must have a Pointee witness");
219+
220+
// Take the default definition of `Iterator` from CxxSequence protocol. This
221+
// type is currently `CxxIterator<Self>`.
222+
auto iteratorDecl = cxxSequenceProto->getAssociatedType(ctx.Id_Iterator);
223+
auto iteratorTy = iteratorDecl->getDefaultDefinitionType();
224+
// Substitute generic `Self` parameter.
225+
auto cxxSequenceSelfTy = cxxSequenceProto->getSelfInterfaceType();
226+
auto declSelfTy = decl->getDeclaredInterfaceType();
227+
// TODO: do we need this special case for foreign reference types?
228+
if (declSelfTy->getClassOrBoundGenericClass())
229+
declSelfTy = DynamicSelfType::get(declSelfTy, decl->getASTContext());
230+
iteratorTy = iteratorTy.subst(
231+
[&](SubstitutableType *dependentType) {
232+
if (dependentType->isEqual(cxxSequenceSelfTy))
233+
return declSelfTy;
234+
return Type(dependentType);
235+
},
236+
LookUpConformanceInModule(decl->getModuleContext()));
237+
238+
impl.addSynthesizedTypealias(decl, ctx.Id_Element, pointeeTy);
239+
impl.addSynthesizedTypealias(decl, ctx.Id_Iterator, iteratorTy);
240+
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("RawIterator"),
241+
rawIteratorTy);
242+
impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxSequence});
243+
}

lib/ClangImporter/ClangDerivedConformances.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,17 @@ namespace swift {
2121
bool isIterator(const clang::CXXRecordDecl *clangDecl);
2222

2323
/// If the decl is a C++ input iterator, synthesize a conformance to the
24-
/// UnsafeCxxInputIterator protocol, which is defined in the std overlay.
24+
/// UnsafeCxxInputIterator protocol, which is defined in the Cxx module.
2525
void conformToCxxIteratorIfNeeded(ClangImporter::Implementation &impl,
2626
NominalTypeDecl *decl,
2727
const clang::CXXRecordDecl *clangDecl);
2828

29+
/// If the decl is a C++ sequence, synthesize a conformance to the CxxSequence
30+
/// protocol, which is defined in the Cxx module.
31+
void conformToCxxSequenceIfNeeded(ClangImporter::Implementation &impl,
32+
NominalTypeDecl *decl,
33+
const clang::CXXRecordDecl *clangDecl);
34+
2935
} // namespace swift
3036

3137
#endif // SWIFT_CLANG_DERIVED_CONFORMANCES_H

lib/ClangImporter/ImportDecl.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2612,6 +2612,7 @@ namespace {
26122612
if (clangModule && requiresCPlusPlus(clangModule)) {
26132613
if (auto structDecl = dyn_cast_or_null<NominalTypeDecl>(result)) {
26142614
conformToCxxIteratorIfNeeded(Impl, structDecl, decl);
2615+
conformToCxxSequenceIfNeeded(Impl, structDecl, decl);
26152616
}
26162617
}
26172618

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5790,6 +5790,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
57905790
case KnownProtocolKind::DistributedTargetInvocationEncoder:
57915791
case KnownProtocolKind::DistributedTargetInvocationDecoder:
57925792
case KnownProtocolKind::DistributedTargetInvocationResultHandler:
5793+
case KnownProtocolKind::CxxSequence:
57935794
case KnownProtocolKind::UnsafeCxxInputIterator:
57945795
case KnownProtocolKind::SerialExecutor:
57955796
case KnownProtocolKind::Sendable:

test/Interop/Cxx/stdlib/overlay/Inputs/custom-sequence.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,51 @@ struct SimpleEmptySequence {
3737
const int *end() const { return nullptr; }
3838
};
3939

40+
struct HasMutatingBeginEnd {
41+
ConstIterator begin() { return ConstIterator(1); }
42+
ConstIterator end() { return ConstIterator(5); }
43+
};
44+
45+
// TODO: this should conform to CxxSequence.
46+
struct __attribute__((swift_attr("import_reference"))) ImmortalSequence {
47+
ConstIterator begin() { return ConstIterator(1); }
48+
ConstIterator end() { return ConstIterator(5); }
49+
};
50+
51+
// MARK: Types that are not actually sequences
52+
53+
struct HasNoBeginMethod {
54+
ConstIterator end() const { return ConstIterator(1); }
55+
};
56+
57+
struct HasNoEndMethod {
58+
ConstIterator begin() const { return ConstIterator(1); }
59+
};
60+
61+
struct HasBeginEndTypeMismatch {
62+
ConstIterator begin() const { return ConstIterator(1); }
63+
ConstIteratorOutOfLineEq end() const { return ConstIteratorOutOfLineEq(3); }
64+
};
65+
66+
struct HasBeginEndReturnNonIterators {
67+
struct NotIterator {};
68+
69+
NotIterator begin() const { return NotIterator(); }
70+
NotIterator end() const { return NotIterator(); }
71+
};
72+
73+
// TODO: this should not be conformed to CxxSequence, because
74+
// `const ConstIterator &` is imported as `UnsafePointer<ConstIterator>`, and
75+
// calling `successor()` is not actually going to call
76+
// `ConstIterator::operator++()`. It will increment the address instead.
77+
struct HasBeginEndReturnRef {
78+
private:
79+
ConstIterator b = ConstIterator(1);
80+
ConstIterator e = ConstIterator(5);
81+
82+
public:
83+
const ConstIterator &begin() const { return b; }
84+
const ConstIterator &end() const { return e; }
85+
};
86+
4087
#endif // TEST_INTEROP_CXX_STDLIB_INPUTS_CUSTOM_SEQUENCE_H
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// RUN: %target-swift-ide-test -print-module -module-to-print=CustomSequence -source-filename=x -I %S/Inputs -enable-experimental-cxx-interop -module-cache-path %t | %FileCheck %s
2+
3+
// CHECK: import Cxx
4+
5+
// CHECK: struct SimpleSequence : CxxSequence {
6+
// CHECK: typealias Element = ConstIterator.Pointee
7+
// CHECK: typealias Iterator = CxxIterator<SimpleSequence>
8+
// CHECK: typealias RawIterator = ConstIterator
9+
// CHECK: }
10+
11+
// CHECK: struct SimpleSequenceWithOutOfLineEqualEqual : CxxSequence {
12+
// CHECK: typealias Element = ConstIteratorOutOfLineEq.Pointee
13+
// CHECK: typealias Iterator = CxxIterator<SimpleSequenceWithOutOfLineEqualEqual>
14+
// CHECK: typealias RawIterator = ConstIteratorOutOfLineEq
15+
// CHECK: }
16+
17+
// CHECK: struct SimpleArrayWrapper : CxxSequence {
18+
// CHECK: typealias Element = UnsafePointer<Int32>.Pointee
19+
// CHECK: typealias Iterator = CxxIterator<SimpleArrayWrapper>
20+
// CHECK: typealias RawIterator = UnsafePointer<Int32>
21+
// CHECK: }
22+
23+
// CHECK: struct SimpleArrayWrapperNullableIterators : CxxSequence {
24+
// CHECK: typealias Element = Optional<UnsafePointer<Int32>>.Pointee
25+
// CHECK: typealias Iterator = CxxIterator<SimpleArrayWrapperNullableIterators>
26+
// CHECK: typealias RawIterator = UnsafePointer<Int32>?
27+
// CHECK: }
28+
29+
// CHECK: struct SimpleEmptySequence : CxxSequence {
30+
// CHECK: typealias Element = Optional<UnsafePointer<Int32>>.Pointee
31+
// CHECK: typealias Iterator = CxxIterator<SimpleEmptySequence>
32+
// CHECK: typealias RawIterator = UnsafePointer<Int32>?
33+
// CHECK: }
34+
35+
// CHECK: struct HasMutatingBeginEnd : CxxSequence {
36+
// CHECK: typealias Element = ConstIterator.Pointee
37+
// CHECK: typealias Iterator = CxxIterator<HasMutatingBeginEnd>
38+
// CHECK: typealias RawIterator = ConstIterator
39+
// CHECK: }
40+
41+
// CHECK: struct HasNoBeginMethod {
42+
// CHECK-NOT: typealias Element
43+
// CHECK-NOT: typealias Iterator
44+
// CHECK-NOT: typealias RawIterator
45+
// CHECK: }
46+
// CHECK: struct HasNoEndMethod {
47+
// CHECK-NOT: typealias Element
48+
// CHECK-NOT: typealias Iterator
49+
// CHECK-NOT: typealias RawIterator
50+
// CHECK: }
51+
// CHECK: struct HasBeginEndTypeMismatch {
52+
// CHECK-NOT: typealias Element
53+
// CHECK-NOT: typealias Iterator
54+
// CHECK-NOT: typealias RawIterator
55+
// CHECK: }
56+
// CHECK: struct HasBeginEndReturnNonIterators {
57+
// CHECK-NOT: typealias Element
58+
// CHECK-NOT: typealias Iterator
59+
// CHECK-NOT: typealias RawIterator
60+
// CHECK: }

test/Interop/Cxx/stdlib/overlay/custom-sequence-typechecker.swift

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,7 @@
33
import CustomSequence
44
import Cxx
55

6-
// === SimpleSequence ===
7-
// Conformance to UnsafeCxxInputIterator is synthesized.
8-
extension SimpleSequence: CxxSequence {}
9-
10-
func checkSimpleSequence() {
11-
let seq = SimpleSequence()
6+
func checkIntSequence<S>(_ seq: S) where S: Sequence, S.Element == Int32 {
127
let contains = seq.contains(where: { $0 == 3 })
138
print(contains)
149

@@ -17,17 +12,26 @@ func checkSimpleSequence() {
1712
}
1813
}
1914

15+
// === SimpleSequence ===
16+
// Conformance to UnsafeCxxInputIterator is synthesized.
17+
// Conformance to CxxSequence is synthesized.
18+
checkIntSequence(SimpleSequence())
19+
2020
// === SimpleSequenceWithOutOfLineEqualEqual ===
21-
extension SimpleSequenceWithOutOfLineEqualEqual : CxxSequence {}
21+
// Conformance to CxxSequence is synthesized.
22+
checkIntSequence(SimpleSequenceWithOutOfLineEqualEqual())
2223

2324
// === SimpleArrayWrapper ===
2425
// No UnsafeCxxInputIterator conformance required, since the iterators are actually UnsafePointers here.
25-
extension SimpleArrayWrapper: CxxSequence {}
26+
// Conformance to CxxSequence is synthesized.
27+
checkIntSequence(SimpleArrayWrapper())
2628

2729
// === SimpleArrayWrapperNullableIterators ===
2830
// No UnsafeCxxInputIterator conformance required, since the iterators are actually optional UnsafePointers here.
29-
extension SimpleArrayWrapperNullableIterators: CxxSequence {}
31+
// Conformance to CxxSequence is synthesized.
32+
checkIntSequence(SimpleArrayWrapperNullableIterators())
3033

3134
// === SimpleEmptySequence ===
3235
// No UnsafeCxxInputIterator conformance required, since the iterators are actually optional UnsafePointers here.
33-
extension SimpleEmptySequence: CxxSequence {}
36+
// Conformance to CxxSequence is synthesized.
37+
checkIntSequence(SimpleEmptySequence())

test/Interop/Cxx/stdlib/overlay/custom-sequence.swift

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,9 @@
55

66
import StdlibUnittest
77
import CustomSequence
8-
import Cxx
98

109
var CxxSequenceTestSuite = TestSuite("CxxSequence")
1110

12-
extension SimpleSequence: CxxSequence {}
13-
14-
extension SimpleEmptySequence: CxxSequence {}
15-
16-
1711
CxxSequenceTestSuite.test("SimpleSequence as Swift.Sequence") {
1812
let seq = SimpleSequence()
1913
let contains = seq.contains(where: { $0 == 3 })

0 commit comments

Comments
 (0)