Skip to content

Commit ea781bf

Browse files
authored
Merge pull request #61700 from apple/egorzhdan/cxx-conform-raciter
[cxx-interop] Synthesize conformances to `UnsafeCxxInputIterator`
2 parents b26225c + 0efd20d commit ea781bf

File tree

7 files changed

+174
-25
lines changed

7 files changed

+174
-25
lines changed

include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ PROTOCOL(DistributedTargetInvocationResultHandler)
107107
// C++ Standard Library Overlay:
108108
PROTOCOL(CxxSequence)
109109
PROTOCOL(UnsafeCxxInputIterator)
110+
PROTOCOL(UnsafeCxxRandomAccessIterator)
110111

111112
PROTOCOL(AsyncSequence)
112113
PROTOCOL(AsyncIteratorProtocol)

lib/AST/ASTContext.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,6 +1115,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
11151115
break;
11161116
case KnownProtocolKind::CxxSequence:
11171117
case KnownProtocolKind::UnsafeCxxInputIterator:
1118+
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
11181119
M = getLoadedModule(Id_Cxx);
11191120
break;
11201121
default:

lib/ClangImporter/ClangDerivedConformances.cpp

Lines changed: 122 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,29 @@ getIteratorCategoryDecl(const clang::CXXRecordDecl *clangDecl) {
5353
return dyn_cast_or_null<clang::TypeDecl>(iteratorCategory);
5454
}
5555

56-
static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
57-
auto id = decl->getASTContext().Id_EqualsOperator;
56+
static ValueDecl *lookupOperator(NominalTypeDecl *decl, Identifier id,
57+
function_ref<bool(ValueDecl *)> isValid) {
58+
// First look for operator declared as a member.
59+
auto memberResults = lookupDirectWithoutExtensions(decl, id);
60+
for (const auto &member : memberResults) {
61+
if (isValid(member))
62+
return member;
63+
}
64+
65+
// If no member operator was found, look for out-of-class definitions in the
66+
// same module.
67+
auto module = decl->getModuleContext();
68+
SmallVector<ValueDecl *> nonMemberResults;
69+
module->lookupValue(id, NLKind::UnqualifiedLookup, nonMemberResults);
70+
for (const auto &nonMember : nonMemberResults) {
71+
if (isValid(nonMember))
72+
return nonMember;
73+
}
5874

75+
return nullptr;
76+
}
77+
78+
static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
5979
auto isValid = [&](ValueDecl *equalEqualOp) -> bool {
6080
auto equalEqual = dyn_cast<FuncDecl>(equalEqualOp);
6181
if (!equalEqual || !equalEqual->hasParameterList())
@@ -78,24 +98,72 @@ static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
7898
return true;
7999
};
80100

81-
// First look for `func ==` declared as a member.
82-
auto memberResults = lookupDirectWithoutExtensions(decl, id);
83-
for (const auto &member : memberResults) {
84-
if (isValid(member))
85-
return member;
86-
}
101+
return lookupOperator(decl, decl->getASTContext().Id_EqualsOperator, isValid);
102+
}
87103

88-
// If no member `func ==` was found, look for out-of-class definitions in the
89-
// same module.
104+
static ValueDecl *getMinusOperator(NominalTypeDecl *decl) {
105+
auto binaryIntegerProto =
106+
decl->getASTContext().getProtocol(KnownProtocolKind::BinaryInteger);
90107
auto module = decl->getModuleContext();
91-
SmallVector<ValueDecl *> nonMemberResults;
92-
module->lookupValue(id, NLKind::UnqualifiedLookup, nonMemberResults);
93-
for (const auto &nonMember : nonMemberResults) {
94-
if (isValid(nonMember))
95-
return nonMember;
96-
}
97108

98-
return nullptr;
109+
auto isValid = [&](ValueDecl *minusOp) -> bool {
110+
auto minus = dyn_cast<FuncDecl>(minusOp);
111+
if (!minus || !minus->hasParameterList())
112+
return false;
113+
auto params = minus->getParameters();
114+
if (params->size() != 2)
115+
return false;
116+
auto lhs = params->get(0);
117+
auto rhs = params->get(1);
118+
if (lhs->isInOut() || rhs->isInOut())
119+
return false;
120+
auto lhsTy = lhs->getType();
121+
auto rhsTy = rhs->getType();
122+
if (!lhsTy || !rhsTy)
123+
return false;
124+
auto lhsNominal = lhsTy->getAnyNominal();
125+
auto rhsNominal = rhsTy->getAnyNominal();
126+
if (lhsNominal != rhsNominal || lhsNominal != decl)
127+
return false;
128+
auto returnTy = minus->getResultInterfaceType();
129+
if (!module->conformsToProtocol(returnTy, binaryIntegerProto))
130+
return false;
131+
return true;
132+
};
133+
134+
return lookupOperator(decl, decl->getASTContext().getIdentifier("-"),
135+
isValid);
136+
}
137+
138+
static ValueDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
139+
auto isValid = [&](ValueDecl *plusEqualOp) -> bool {
140+
auto plusEqual = dyn_cast<FuncDecl>(plusEqualOp);
141+
if (!plusEqual || !plusEqual->hasParameterList())
142+
return false;
143+
auto params = plusEqual->getParameters();
144+
if (params->size() != 2)
145+
return false;
146+
auto lhs = params->get(0);
147+
auto rhs = params->get(1);
148+
if (rhs->isInOut())
149+
return false;
150+
auto lhsTy = lhs->getType();
151+
auto rhsTy = rhs->getType();
152+
if (!lhsTy || !rhsTy)
153+
return false;
154+
if (rhsTy->getCanonicalType() != distanceTy->getCanonicalType())
155+
return false;
156+
auto lhsNominal = lhsTy->getAnyNominal();
157+
if (lhsNominal != decl)
158+
return false;
159+
auto returnTy = plusEqual->getResultInterfaceType();
160+
if (!returnTy->isVoid())
161+
return false;
162+
return true;
163+
};
164+
165+
return lookupOperator(decl, decl->getASTContext().getIdentifier("+="),
166+
isValid);
99167
}
100168

101169
bool swift::isIterator(const clang::CXXRecordDecl *clangDecl) {
@@ -111,6 +179,9 @@ void swift::conformToCxxIteratorIfNeeded(
111179
assert(clangDecl);
112180
ASTContext &ctx = decl->getASTContext();
113181

182+
if (!ctx.getProtocol(KnownProtocolKind::UnsafeCxxInputIterator))
183+
return;
184+
114185
// We consider a type to be an input iterator if it defines an
115186
// `iterator_category` that inherits from `std::input_iterator_tag`, e.g.
116187
// `using iterator_category = std::input_iterator_tag`.
@@ -134,17 +205,30 @@ void swift::conformToCxxIteratorIfNeeded(
134205
if (!underlyingCategoryDecl)
135206
return;
136207

137-
auto isInputIteratorDecl = [&](const clang::CXXRecordDecl *base) {
208+
auto isIteratorCategoryDecl = [&](const clang::CXXRecordDecl *base,
209+
StringRef tag) {
138210
return base->isInStdNamespace() && base->getIdentifier() &&
139-
base->getName() == "input_iterator_tag";
211+
base->getName() == tag;
212+
};
213+
auto isInputIteratorDecl = [&](const clang::CXXRecordDecl *base) {
214+
return isIteratorCategoryDecl(base, "input_iterator_tag");
215+
};
216+
auto isRandomAccessIteratorDecl = [&](const clang::CXXRecordDecl *base) {
217+
return isIteratorCategoryDecl(base, "random_access_iterator_tag");
140218
};
141219

142220
// Traverse all transitive bases of `underlyingDecl` to check if
143221
// it inherits from `std::input_iterator_tag`.
144222
bool isInputIterator = isInputIteratorDecl(underlyingCategoryDecl);
223+
bool isRandomAccessIterator =
224+
isRandomAccessIteratorDecl(underlyingCategoryDecl);
145225
underlyingCategoryDecl->forallBases([&](const clang::CXXRecordDecl *base) {
146226
if (isInputIteratorDecl(base)) {
147227
isInputIterator = true;
228+
}
229+
if (isRandomAccessIteratorDecl(base)) {
230+
isRandomAccessIterator = true;
231+
isInputIterator = true;
148232
return false;
149233
}
150234
return true;
@@ -183,6 +267,25 @@ void swift::conformToCxxIteratorIfNeeded(
183267
pointee->getType());
184268
impl.addSynthesizedProtocolAttrs(decl,
185269
{KnownProtocolKind::UnsafeCxxInputIterator});
270+
if (!isRandomAccessIterator ||
271+
!ctx.getProtocol(KnownProtocolKind::UnsafeCxxRandomAccessIterator))
272+
return;
273+
274+
// Try to conform to UnsafeCxxRandomAccessIterator if possible.
275+
276+
auto minus = dyn_cast<FuncDecl>(getMinusOperator(decl));
277+
if (!minus)
278+
return;
279+
auto distanceTy = minus->getResultInterfaceType();
280+
// distanceTy conforms to BinaryInteger, this is ensured by getMinusOperator.
281+
282+
auto plusEqual = dyn_cast<FuncDecl>(getPlusEqualOperator(decl, distanceTy));
283+
if (!plusEqual)
284+
return;
285+
286+
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Distance"), distanceTy);
287+
impl.addSynthesizedProtocolAttrs(
288+
decl, {KnownProtocolKind::UnsafeCxxRandomAccessIterator});
186289
}
187290

188291
void swift::conformToCxxSequenceIfNeeded(

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5886,6 +5886,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
58865886
case KnownProtocolKind::DistributedTargetInvocationResultHandler:
58875887
case KnownProtocolKind::CxxSequence:
58885888
case KnownProtocolKind::UnsafeCxxInputIterator:
5889+
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
58895890
case KnownProtocolKind::SerialExecutor:
58905891
case KnownProtocolKind::Sendable:
58915892
case KnownProtocolKind::UnsafeSendable:

test/Interop/Cxx/stdlib/overlay/Inputs/custom-iterator.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,25 @@ struct HasCustomIteratorTag {
240240
}
241241
};
242242

243+
struct HasCustomRACIteratorTag {
244+
struct CustomTag : public std::random_access_iterator_tag {};
245+
246+
int value;
247+
using iterator_category = CustomTag;
248+
const int &operator*() const { return value; }
249+
HasCustomRACIteratorTag &operator++() {
250+
value++;
251+
return *this;
252+
}
253+
void operator+=(int x) { value += x; }
254+
int operator-(const HasCustomRACIteratorTag &x) const {
255+
return value - x.value;
256+
}
257+
bool operator==(const HasCustomRACIteratorTag &other) const {
258+
return value == other.value;
259+
}
260+
};
261+
243262
struct HasCustomIteratorTagInline {
244263
struct iterator_category : public std::input_iterator_tag {};
245264

test/Interop/Cxx/stdlib/overlay/custom-collection.swift

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@ var CxxCollectionTestSuite = TestSuite("CxxCollection")
1111

1212
// === SimpleCollectionNoSubscript ===
1313

14-
extension SimpleCollectionNoSubscript.iterator : UnsafeCxxRandomAccessIterator {
15-
public typealias Distance = difference_type
16-
}
1714
extension SimpleCollectionNoSubscript : CxxRandomAccessCollection {
1815
}
1916

@@ -25,9 +22,6 @@ CxxCollectionTestSuite.test("SimpleCollectionNoSubscript as Swift.Collection") {
2522

2623
// === SimpleCollectionReadOnly ===
2724

28-
extension SimpleCollectionReadOnly.iterator : UnsafeCxxRandomAccessIterator {
29-
public typealias Distance = difference_type
30-
}
3125
extension SimpleCollectionReadOnly : CxxRandomAccessCollection {
3226
}
3327

test/Interop/Cxx/stdlib/overlay/custom-iterator-module-interface.swift

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,26 @@
77
// CHECK: typealias Pointee = Int32
88
// CHECK: }
99

10+
// CHECK: struct ConstRACIterator : UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
11+
// CHECK: var pointee: Int32 { get }
12+
// CHECK: func successor() -> ConstRACIterator
13+
// CHECK: static func += (lhs: inout ConstRACIterator, v: ConstRACIterator.difference_type)
14+
// CHECK: static func - (lhs: ConstRACIterator, other: ConstRACIterator) -> Int32
15+
// CHECK: static func == (lhs: ConstRACIterator, other: ConstRACIterator) -> Bool
16+
// CHECK: typealias Pointee = Int32
17+
// CHECK: typealias Distance = Int32
18+
// CHECK: }
19+
20+
// CHECK: struct ConstRACIteratorRefPlusEq : UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
21+
// CHECK: var pointee: Int32 { get }
22+
// CHECK: func successor() -> ConstRACIterator
23+
// CHECK: static func += (lhs: inout ConstRACIteratorRefPlusEq, v: ConstRACIteratorRefPlusEq.difference_type)
24+
// CHECK: static func - (lhs: ConstRACIteratorRefPlusEq, other: ConstRACIteratorRefPlusEq) -> Int32
25+
// CHECK: static func == (lhs: ConstRACIteratorRefPlusEq, other: ConstRACIteratorRefPlusEq) -> Bool
26+
// CHECK: typealias Pointee = Int32
27+
// CHECK: typealias Distance = Int32
28+
// CHECK: }
29+
1030
// CHECK: struct ConstIteratorOutOfLineEq : UnsafeCxxInputIterator {
1131
// CHECK: var pointee: Int32 { get }
1232
// CHECK: func successor() -> ConstIteratorOutOfLineEq
@@ -34,6 +54,16 @@
3454
// CHECK: typealias Pointee = Int32
3555
// CHECK: }
3656

57+
// CHECK: struct HasCustomRACIteratorTag : UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
58+
// CHECK: var pointee: Int32 { get }
59+
// CHECK: func successor() -> HasCustomRACIteratorTag
60+
// CHECK: static func += (lhs: inout HasCustomRACIteratorTag, x: Int32)
61+
// CHECK: static func - (lhs: HasCustomRACIteratorTag, x: HasCustomRACIteratorTag) -> Int32
62+
// CHECK: static func == (lhs: HasCustomRACIteratorTag, other: HasCustomRACIteratorTag) -> Bool
63+
// CHECK: typealias Pointee = Int32
64+
// CHECK: typealias Distance = Int32
65+
// CHECK: }
66+
3767
// CHECK: struct HasCustomIteratorTagInline : UnsafeCxxInputIterator {
3868
// CHECK: var pointee: Int32 { get }
3969
// CHECK: func successor() -> HasCustomIteratorTagInline

0 commit comments

Comments
 (0)