Skip to content

Commit 109d9c1

Browse files
authored
Merge pull request #60155 from apple/egorzhdan/cxx-iterator-ool-equals
[cxx-interop] Allow iterators with out-of-class `operator==`
2 parents a83a17c + 88167d7 commit 109d9c1

File tree

4 files changed

+63
-22
lines changed

4 files changed

+63
-22
lines changed

lib/ClangImporter/ClangDerivedConformances.cpp

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,51 @@ getIteratorCategoryDecl(const clang::CXXRecordDecl *clangDecl) {
2828
return dyn_cast_or_null<clang::TypeDecl>(iteratorCategory);
2929
}
3030

31+
static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
32+
auto id = decl->getASTContext().Id_EqualsOperator;
33+
34+
auto isValid = [&](ValueDecl *equalEqualOp) -> bool {
35+
auto equalEqual = dyn_cast<FuncDecl>(equalEqualOp);
36+
if (!equalEqual || !equalEqual->hasParameterList())
37+
return false;
38+
auto params = equalEqual->getParameters();
39+
if (params->size() != 2)
40+
return false;
41+
auto lhs = params->get(0);
42+
auto rhs = params->get(1);
43+
if (lhs->isInOut() || rhs->isInOut())
44+
return false;
45+
auto lhsTy = lhs->getType();
46+
auto rhsTy = rhs->getType();
47+
if (!lhsTy || !rhsTy)
48+
return false;
49+
auto lhsNominal = lhsTy->getAnyNominal();
50+
auto rhsNominal = rhsTy->getAnyNominal();
51+
if (lhsNominal != rhsNominal || lhsNominal != decl)
52+
return false;
53+
return true;
54+
};
55+
56+
// First look for `func ==` declared as a member.
57+
auto memberResults = decl->lookupDirect(id);
58+
for (const auto &member : memberResults) {
59+
if (isValid(member))
60+
return member;
61+
}
62+
63+
// If no member `func ==` was found, look for out-of-class definitions in the
64+
// same module.
65+
auto module = decl->getModuleContext();
66+
llvm::SmallVector<ValueDecl *> nonMemberResults;
67+
module->lookupValue(id, NLKind::UnqualifiedLookup, nonMemberResults);
68+
for (const auto &nonMember : nonMemberResults) {
69+
if (isValid(nonMember))
70+
return nonMember;
71+
}
72+
73+
return nullptr;
74+
}
75+
3176
bool swift::isIterator(const clang::CXXRecordDecl *clangDecl) {
3277
return getIteratorCategoryDecl(clangDecl);
3378
}
@@ -103,24 +148,8 @@ void swift::conformToCxxIteratorIfNeeded(
103148
return;
104149

105150
// Check if present: `func ==`
106-
// FIXME: this only detects `operator==` declared as a member.
107-
auto equalEquals = decl->lookupDirect(ctx.Id_EqualsOperator);
108-
if (equalEquals.empty())
109-
return;
110-
auto equalEqual = dyn_cast<FuncDecl>(equalEquals.front());
111-
if (!equalEqual || !equalEqual->hasParameterList())
112-
return;
113-
auto equalEqualParams = equalEqual->getParameters();
114-
if (equalEqualParams->size() != 2)
115-
return;
116-
auto equalEqualLHS = equalEqualParams->get(0);
117-
auto equalEqualRHS = equalEqualParams->get(1);
118-
if (equalEqualLHS->isInOut() || equalEqualRHS->isInOut())
119-
return;
120-
auto equalEqualLHSTy = equalEqualLHS->getType();
121-
auto equalEqualRHSTy = equalEqualRHS->getType();
122-
if (!equalEqualLHSTy || !equalEqualRHSTy ||
123-
equalEqualLHSTy->getAnyNominal() != equalEqualRHSTy->getAnyNominal())
151+
auto equalEqual = getEqualEqualOperator(decl);
152+
if (!equalEqual)
124153
return;
125154

126155
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Pointee"),

test/Interop/Cxx/stdlib/overlay/Inputs/custom-iterator.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,19 @@ struct HasNoEqualEqual {
191191
}
192192
};
193193

194+
struct HasInvalidEqualEqual {
195+
int value;
196+
using iterator_category = std::input_iterator_tag;
197+
const int &operator*() const { return value; }
198+
HasInvalidEqualEqual &operator++() {
199+
value++;
200+
return *this;
201+
}
202+
bool operator==(const int &other) const { // wrong type
203+
return value == other;
204+
}
205+
};
206+
194207
struct HasNoIncrementOperator {
195208
int value;
196209
using iterator_category = std::input_iterator_tag;

test/Interop/Cxx/stdlib/overlay/custom-iterator-module-interface.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
// CHECK: static func == (lhs: ConstIterator, other: ConstIterator) -> Bool
1010
// CHECK: }
1111

12-
// TODO: ConstIteratorOutOfLineEq should also conform to UnsafeCxxInputIterator.
13-
// CHECK: struct ConstIteratorOutOfLineEq {
12+
// CHECK: struct ConstIteratorOutOfLineEq : UnsafeCxxInputIterator {
1413
// CHECK: var pointee: Int32 { get }
1514
// CHECK: func successor() -> ConstIteratorOutOfLineEq
1615
// CHECK: }
@@ -54,6 +53,7 @@
5453
// CHECK-NOT: struct HasNoIteratorCategory : UnsafeCxxInputIterator
5554
// CHECK-NOT: struct HasInvalidIteratorCategory : UnsafeCxxInputIterator
5655
// CHECK-NOT: struct HasNoEqualEqual : UnsafeCxxInputIterator
56+
// CHECK-NOT: struct HasInvalidEqualEqual : UnsafeCxxInputIterator
5757
// CHECK-NOT: struct HasNoIncrementOperator : UnsafeCxxInputIterator
5858
// CHECK-NOT: struct HasNoPreIncrementOperator : UnsafeCxxInputIterator
5959
// CHECK-NOT: struct HasNoDereferenceOperator : UnsafeCxxInputIterator

test/Interop/Cxx/stdlib/overlay/custom-sequence-typechecker.swift

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ func checkSimpleSequence() {
1818
}
1919

2020
// === SimpleSequenceWithOutOfLineEqualEqual ===
21-
// TODO: synthesize conformance to UnsafeCxxInputIterator.
22-
//extension SimpleSequenceWithOutOfLineEqualEqual : CxxSequence {}
21+
extension SimpleSequenceWithOutOfLineEqualEqual : CxxSequence {}
2322

2423
// === SimpleArrayWrapper ===
2524
// No UnsafeCxxInputIterator conformance required, since the iterators are actually UnsafePointers here.

0 commit comments

Comments
 (0)