@@ -53,9 +53,29 @@ getIteratorCategoryDecl(const clang::CXXRecordDecl *clangDecl) {
53
53
return dyn_cast_or_null<clang::TypeDecl>(iteratorCategory);
54
54
}
55
55
56
- static ValueDecl *getEqualEqualOperator (NominalTypeDecl *decl) {
57
- auto id = decl->getASTContext ().Id_EqualsOperator ;
56
+ static ValueDecl *lookupOperator (NominalTypeDecl *decl, Identifier id,
57
+ function_ref<bool (ValueDecl *)> isValid) {
58
+ // First look for operator declared as a member.
59
+ auto memberResults = lookupDirectWithoutExtensions (decl, id);
60
+ for (const auto &member : memberResults) {
61
+ if (isValid (member))
62
+ return member;
63
+ }
64
+
65
+ // If no member operator was found, look for out-of-class definitions in the
66
+ // same module.
67
+ auto module = decl->getModuleContext ();
68
+ SmallVector<ValueDecl *> nonMemberResults;
69
+ module ->lookupValue (id, NLKind::UnqualifiedLookup, nonMemberResults);
70
+ for (const auto &nonMember : nonMemberResults) {
71
+ if (isValid (nonMember))
72
+ return nonMember;
73
+ }
58
74
75
+ return nullptr ;
76
+ }
77
+
78
+ static ValueDecl *getEqualEqualOperator (NominalTypeDecl *decl) {
59
79
auto isValid = [&](ValueDecl *equalEqualOp) -> bool {
60
80
auto equalEqual = dyn_cast<FuncDecl>(equalEqualOp);
61
81
if (!equalEqual || !equalEqual->hasParameterList ())
@@ -78,24 +98,72 @@ static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
78
98
return true ;
79
99
};
80
100
81
- // First look for `func ==` declared as a member.
82
- auto memberResults = lookupDirectWithoutExtensions (decl, id);
83
- for (const auto &member : memberResults) {
84
- if (isValid (member))
85
- return member;
86
- }
101
+ return lookupOperator (decl, decl->getASTContext ().Id_EqualsOperator , isValid);
102
+ }
87
103
88
- // If no member `func ==` was found, look for out-of-class definitions in the
89
- // same module.
104
+ static ValueDecl *getMinusOperator (NominalTypeDecl *decl) {
105
+ auto binaryIntegerProto =
106
+ decl->getASTContext ().getProtocol (KnownProtocolKind::BinaryInteger);
90
107
auto module = decl->getModuleContext ();
91
- SmallVector<ValueDecl *> nonMemberResults;
92
- module ->lookupValue (id, NLKind::UnqualifiedLookup, nonMemberResults);
93
- for (const auto &nonMember : nonMemberResults) {
94
- if (isValid (nonMember))
95
- return nonMember;
96
- }
97
108
98
- return nullptr ;
109
+ auto isValid = [&](ValueDecl *minusOp) -> bool {
110
+ auto minus = dyn_cast<FuncDecl>(minusOp);
111
+ if (!minus || !minus->hasParameterList ())
112
+ return false ;
113
+ auto params = minus->getParameters ();
114
+ if (params->size () != 2 )
115
+ return false ;
116
+ auto lhs = params->get (0 );
117
+ auto rhs = params->get (1 );
118
+ if (lhs->isInOut () || rhs->isInOut ())
119
+ return false ;
120
+ auto lhsTy = lhs->getType ();
121
+ auto rhsTy = rhs->getType ();
122
+ if (!lhsTy || !rhsTy)
123
+ return false ;
124
+ auto lhsNominal = lhsTy->getAnyNominal ();
125
+ auto rhsNominal = rhsTy->getAnyNominal ();
126
+ if (lhsNominal != rhsNominal || lhsNominal != decl)
127
+ return false ;
128
+ auto returnTy = minus->getResultInterfaceType ();
129
+ if (!module ->conformsToProtocol (returnTy, binaryIntegerProto))
130
+ return false ;
131
+ return true ;
132
+ };
133
+
134
+ return lookupOperator (decl, decl->getASTContext ().getIdentifier (" -" ),
135
+ isValid);
136
+ }
137
+
138
+ static ValueDecl *getPlusEqualOperator (NominalTypeDecl *decl, Type distanceTy) {
139
+ auto isValid = [&](ValueDecl *plusEqualOp) -> bool {
140
+ auto plusEqual = dyn_cast<FuncDecl>(plusEqualOp);
141
+ if (!plusEqual || !plusEqual->hasParameterList ())
142
+ return false ;
143
+ auto params = plusEqual->getParameters ();
144
+ if (params->size () != 2 )
145
+ return false ;
146
+ auto lhs = params->get (0 );
147
+ auto rhs = params->get (1 );
148
+ if (rhs->isInOut ())
149
+ return false ;
150
+ auto lhsTy = lhs->getType ();
151
+ auto rhsTy = rhs->getType ();
152
+ if (!lhsTy || !rhsTy)
153
+ return false ;
154
+ if (rhsTy->getCanonicalType () != distanceTy->getCanonicalType ())
155
+ return false ;
156
+ auto lhsNominal = lhsTy->getAnyNominal ();
157
+ if (lhsNominal != decl)
158
+ return false ;
159
+ auto returnTy = plusEqual->getResultInterfaceType ();
160
+ if (!returnTy->isVoid ())
161
+ return false ;
162
+ return true ;
163
+ };
164
+
165
+ return lookupOperator (decl, decl->getASTContext ().getIdentifier (" +=" ),
166
+ isValid);
99
167
}
100
168
101
169
bool swift::isIterator (const clang::CXXRecordDecl *clangDecl) {
@@ -111,6 +179,9 @@ void swift::conformToCxxIteratorIfNeeded(
111
179
assert (clangDecl);
112
180
ASTContext &ctx = decl->getASTContext ();
113
181
182
+ if (!ctx.getProtocol (KnownProtocolKind::UnsafeCxxInputIterator))
183
+ return ;
184
+
114
185
// We consider a type to be an input iterator if it defines an
115
186
// `iterator_category` that inherits from `std::input_iterator_tag`, e.g.
116
187
// `using iterator_category = std::input_iterator_tag`.
@@ -134,17 +205,30 @@ void swift::conformToCxxIteratorIfNeeded(
134
205
if (!underlyingCategoryDecl)
135
206
return ;
136
207
137
- auto isInputIteratorDecl = [&](const clang::CXXRecordDecl *base) {
208
+ auto isIteratorCategoryDecl = [&](const clang::CXXRecordDecl *base,
209
+ StringRef tag) {
138
210
return base->isInStdNamespace () && base->getIdentifier () &&
139
- base->getName () == " input_iterator_tag" ;
211
+ base->getName () == tag;
212
+ };
213
+ auto isInputIteratorDecl = [&](const clang::CXXRecordDecl *base) {
214
+ return isIteratorCategoryDecl (base, " input_iterator_tag" );
215
+ };
216
+ auto isRandomAccessIteratorDecl = [&](const clang::CXXRecordDecl *base) {
217
+ return isIteratorCategoryDecl (base, " random_access_iterator_tag" );
140
218
};
141
219
142
220
// Traverse all transitive bases of `underlyingDecl` to check if
143
221
// it inherits from `std::input_iterator_tag`.
144
222
bool isInputIterator = isInputIteratorDecl (underlyingCategoryDecl);
223
+ bool isRandomAccessIterator =
224
+ isRandomAccessIteratorDecl (underlyingCategoryDecl);
145
225
underlyingCategoryDecl->forallBases ([&](const clang::CXXRecordDecl *base) {
146
226
if (isInputIteratorDecl (base)) {
147
227
isInputIterator = true ;
228
+ }
229
+ if (isRandomAccessIteratorDecl (base)) {
230
+ isRandomAccessIterator = true ;
231
+ isInputIterator = true ;
148
232
return false ;
149
233
}
150
234
return true ;
@@ -183,6 +267,25 @@ void swift::conformToCxxIteratorIfNeeded(
183
267
pointee->getType ());
184
268
impl.addSynthesizedProtocolAttrs (decl,
185
269
{KnownProtocolKind::UnsafeCxxInputIterator});
270
+ if (!isRandomAccessIterator ||
271
+ !ctx.getProtocol (KnownProtocolKind::UnsafeCxxRandomAccessIterator))
272
+ return ;
273
+
274
+ // Try to conform to UnsafeCxxRandomAccessIterator if possible.
275
+
276
+ auto minus = dyn_cast<FuncDecl>(getMinusOperator (decl));
277
+ if (!minus)
278
+ return ;
279
+ auto distanceTy = minus->getResultInterfaceType ();
280
+ // distanceTy conforms to BinaryInteger, this is ensured by getMinusOperator.
281
+
282
+ auto plusEqual = dyn_cast<FuncDecl>(getPlusEqualOperator (decl, distanceTy));
283
+ if (!plusEqual)
284
+ return ;
285
+
286
+ impl.addSynthesizedTypealias (decl, ctx.getIdentifier (" Distance" ), distanceTy);
287
+ impl.addSynthesizedProtocolAttrs (
288
+ decl, {KnownProtocolKind::UnsafeCxxRandomAccessIterator});
186
289
}
187
290
188
291
void swift::conformToCxxSequenceIfNeeded (
0 commit comments