Skip to content

Commit 1f920d5

Browse files
committed
[cxx-interop] Instantiate templated operator== for iterator types
C++ iterator types are often templated, and sometimes declare `operator==` as a non-member templated function. In libc++, an example of this is `__wrap_iter` which is used as an iterator type for `std::vector` and `std::string`. We don't currently import templated non-member operators into Swift, however, we still want to support common C++ iterator patterns. This change adds logic to instantiate templated non-member `operator==` for types that define `iterator_category` and are therefore likely to be valid iterator types. rdar://97915515
1 parent ac10f2a commit 1f920d5

10 files changed

+178
-29
lines changed

lib/ClangImporter/ClangDerivedConformances.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#include "swift/AST/PrettyStackTrace.h"
1616
#include "swift/AST/ProtocolConformance.h"
1717
#include "swift/ClangImporter/ClangImporterRequests.h"
18+
#include "clang/Sema/DelayedDiagnostic.h"
19+
#include "clang/Sema/Overload.h"
1820

1921
using namespace swift;
2022
using namespace swift::importer;
@@ -202,6 +204,44 @@ static ValueDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
202204
isValid);
203205
}
204206

207+
static void instantiateTemplatedOperator(
208+
ClangImporter::Implementation &impl,
209+
const clang::ClassTemplateSpecializationDecl *classDecl,
210+
clang::BinaryOperatorKind operatorKind) {
211+
212+
clang::ASTContext &clangCtx = impl.getClangASTContext();
213+
clang::Sema &clangSema = impl.getClangSema();
214+
215+
clang::UnresolvedSet<1> ops;
216+
auto qualType = clang::QualType(classDecl->getTypeForDecl(), 0);
217+
auto arg = new (clangCtx)
218+
clang::CXXThisExpr(clang::SourceLocation(), qualType, false);
219+
arg->setType(clang::QualType(classDecl->getTypeForDecl(), 0));
220+
221+
clang::OverloadedOperatorKind opKind =
222+
clang::BinaryOperator::getOverloadedOperator(operatorKind);
223+
clang::OverloadCandidateSet candidateSet(
224+
classDecl->getLocation(), clang::OverloadCandidateSet::CSK_Operator,
225+
clang::OverloadCandidateSet::OperatorRewriteInfo(opKind, false));
226+
clangSema.LookupOverloadedBinOp(candidateSet, opKind, ops, {arg, arg}, true);
227+
228+
clang::OverloadCandidateSet::iterator best;
229+
switch (candidateSet.BestViableFunction(clangSema, clang::SourceLocation(),
230+
best)) {
231+
case clang::OR_Success: {
232+
if (auto clangCallee = best->Function) {
233+
auto lookupTable = impl.findLookupTable(classDecl);
234+
addEntryToLookupTable(*lookupTable, clangCallee, impl.getNameImporter());
235+
}
236+
break;
237+
}
238+
case clang::OR_No_Viable_Function:
239+
case clang::OR_Ambiguous:
240+
case clang::OR_Deleted:
241+
break;
242+
}
243+
}
244+
205245
bool swift::isIterator(const clang::CXXRecordDecl *clangDecl) {
206246
return getIteratorCategoryDecl(clangDecl);
207247
}
@@ -294,6 +334,13 @@ void swift::conformToCxxIteratorIfNeeded(
294334
if (!successorTy || successorTy->getAnyNominal() != decl)
295335
return;
296336

337+
// If this is a templated class, `operator==` might be templated as well.
338+
// Try to instantiate it.
339+
if (auto templateSpec =
340+
dyn_cast<clang::ClassTemplateSpecializationDecl>(clangDecl)) {
341+
instantiateTemplatedOperator(impl, templateSpec,
342+
clang::BinaryOperatorKind::BO_EQ);
343+
}
297344
// Check if present: `func ==`
298345
auto equalEqual = getEqualEqualOperator(decl);
299346
if (!equalEqual)
@@ -309,6 +356,11 @@ void swift::conformToCxxIteratorIfNeeded(
309356

310357
// Try to conform to UnsafeCxxRandomAccessIterator if possible.
311358

359+
if (auto templateSpec =
360+
dyn_cast<clang::ClassTemplateSpecializationDecl>(clangDecl)) {
361+
instantiateTemplatedOperator(impl, templateSpec,
362+
clang::BinaryOperatorKind::BO_Sub);
363+
}
312364
auto minus = dyn_cast_or_null<FuncDecl>(getMinusOperator(decl));
313365
if (!minus)
314366
return;

lib/ClangImporter/ClangImporter.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2661,7 +2661,17 @@ getClangOwningModule(ClangNode Node, const clang::ASTContext &ClangCtx) {
26612661
if (const clang::Decl *D = Node.getAsDecl()) {
26622662
auto ExtSource = ClangCtx.getExternalSource();
26632663
assert(ExtSource);
2664-
return ExtSource->getModule(D->getOwningModuleID());
2664+
2665+
auto originalDecl = D;
2666+
if (auto functionDecl = dyn_cast<clang::FunctionDecl>(D)) {
2667+
if (auto pattern = functionDecl->getTemplateInstantiationPattern()) {
2668+
// Function template instantiations don't have an owning Clang module.
2669+
// Let's use the owning module of the template pattern.
2670+
originalDecl = pattern;
2671+
}
2672+
}
2673+
2674+
return ExtSource->getModule(originalDecl->getOwningModuleID());
26652675
}
26662676

26672677
if (const clang::ModuleMacro *M = Node.getAsModuleMacro())

test/Interop/Cxx/stdlib/libcxx-module-interface.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010

1111
// CHECK-IOSFWD: enum std {
1212
// CHECK-IOSFWD: enum __1 {
13-
// CHECK-IOSFWD: struct __CxxTemplateInstNSt3__112basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEEE {
13+
// CHECK-IOSFWD: struct __CxxTemplateInstNSt3__112basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEEE : CxxRandomAccessCollection {
1414
// CHECK-IOSFWD: typealias value_type = CChar
1515
// CHECK-IOSFWD: }
16-
// CHECK-IOSFWD: struct __CxxTemplateInstNSt3__112basic_stringIwNS_11char_traitsIwEENS_9allocatorIwEEEE {
16+
// CHECK-IOSFWD: struct __CxxTemplateInstNSt3__112basic_stringIwNS_11char_traitsIwEENS_9allocatorIwEEEE : CxxRandomAccessCollection {
1717
// CHECK-IOSFWD: typealias value_type = CWideChar
1818
// CHECK-IOSFWD: }
1919
// CHECK-IOSFWD: typealias string = std.__1.__CxxTemplateInstNSt3__112basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEEE

test/Interop/Cxx/stdlib/libstdcxx-module-interface.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
// REQUIRES: OS=linux-gnu
1212

1313
// CHECK-STD: enum std {
14-
// CHECK-STRING: struct {{__CxxTemplateInstNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE|__CxxTemplateInstSs}} {
14+
// CHECK-STRING: struct {{__CxxTemplateInstNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE|__CxxTemplateInstSs}} : CxxRandomAccessCollection {
1515
// CHECK-STRING: typealias value_type = std.__CxxTemplateInstSt11char_traitsIcE.char_type
1616
// CHECK-STRING: }
17-
// CHECK-STRING: struct {{__CxxTemplateInstNSt7__cxx1112basic_stringIwSt11char_traitsIwESaIwEEE|__CxxTemplateInstSbIwSt11char_traitsIwESaIwEE}} {
17+
// CHECK-STRING: struct {{__CxxTemplateInstNSt7__cxx1112basic_stringIwSt11char_traitsIwESaIwEEE|__CxxTemplateInstSbIwSt11char_traitsIwESaIwEE}} : CxxRandomAccessCollection {
1818
// CHECK-STRING: typealias value_type = std.__CxxTemplateInstSt11char_traitsIwE.char_type
1919
// CHECK-STRING: }
2020

test/Interop/Cxx/stdlib/overlay/Inputs/custom-iterator.h

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,4 +496,101 @@ struct TemplatedIterator {
496496

497497
using TemplatedIteratorInt = TemplatedIterator<int>;
498498

499+
template <typename T>
500+
struct TemplatedIteratorOutOfLineEq {
501+
T value;
502+
503+
using iterator_category = std::input_iterator_tag;
504+
using value_type = T;
505+
using pointer = T *;
506+
using reference = const T &;
507+
using difference_type = int;
508+
509+
TemplatedIteratorOutOfLineEq(int value) : value(value) {}
510+
TemplatedIteratorOutOfLineEq(const TemplatedIteratorOutOfLineEq &other) =
511+
default;
512+
513+
const int &operator*() const { return value; }
514+
515+
TemplatedIteratorOutOfLineEq &operator++() {
516+
value++;
517+
return *this;
518+
}
519+
TemplatedIteratorOutOfLineEq operator++(int) {
520+
auto tmp = TemplatedIteratorOutOfLineEq(value);
521+
value++;
522+
return tmp;
523+
}
524+
};
525+
526+
template <typename T>
527+
bool operator==(const TemplatedIteratorOutOfLineEq<T> &lhs,
528+
const TemplatedIteratorOutOfLineEq<T> &rhs) {
529+
return lhs.value == rhs.value;
530+
}
531+
532+
using TemplatedIteratorOutOfLineEqInt = TemplatedIteratorOutOfLineEq<int>;
533+
534+
template <typename T>
535+
struct TemplatedRACIteratorOutOfLineEq {
536+
T value;
537+
538+
using iterator_category = std::random_access_iterator_tag;
539+
using value_type = T;
540+
using pointer = T *;
541+
using reference = const T &;
542+
using difference_type = int;
543+
544+
TemplatedRACIteratorOutOfLineEq(int value) : value(value) {}
545+
TemplatedRACIteratorOutOfLineEq(const TemplatedRACIteratorOutOfLineEq &other) =
546+
default;
547+
548+
const int &operator*() const { return value; }
549+
550+
TemplatedRACIteratorOutOfLineEq &operator++() {
551+
value++;
552+
return *this;
553+
}
554+
TemplatedRACIteratorOutOfLineEq &operator--() {
555+
value--;
556+
return *this;
557+
}
558+
TemplatedRACIteratorOutOfLineEq operator++(int) {
559+
auto tmp = TemplatedRACIteratorOutOfLineEq(value);
560+
value++;
561+
return tmp;
562+
}
563+
564+
TemplatedRACIteratorOutOfLineEq &operator+=(difference_type v) {
565+
value += v;
566+
return *this;
567+
}
568+
TemplatedRACIteratorOutOfLineEq &operator-=(difference_type v) {
569+
value -= v;
570+
return *this;
571+
}
572+
573+
TemplatedRACIteratorOutOfLineEq operator+(difference_type v) const {
574+
return TemplatedRACIteratorOutOfLineEq(value + v);
575+
}
576+
TemplatedRACIteratorOutOfLineEq operator-(difference_type v) const {
577+
return TemplatedRACIteratorOutOfLineEq(value - v);
578+
}
579+
};
580+
581+
template <typename T>
582+
bool operator==(const TemplatedRACIteratorOutOfLineEq<T> &lhs,
583+
const TemplatedRACIteratorOutOfLineEq<T> &rhs) {
584+
return lhs.value == rhs.value;
585+
}
586+
587+
template <typename T>
588+
typename TemplatedRACIteratorOutOfLineEq<T>::difference_type
589+
operator-(const TemplatedRACIteratorOutOfLineEq<T> &lhs,
590+
const TemplatedRACIteratorOutOfLineEq<T> &rhs) {
591+
return lhs.value - rhs.value;
592+
}
593+
594+
using TemplatedRACIteratorOutOfLineEqInt = TemplatedRACIteratorOutOfLineEq<int>;
595+
499596
#endif // TEST_INTEROP_CXX_STDLIB_INPUTS_CUSTOM_ITERATOR_H

test/Interop/Cxx/stdlib/overlay/custom-iterator-module-interface.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,16 @@
9292
// CHECK: typealias Pointee = Int32
9393
// CHECK: static func == (lhs: __CxxTemplateInst17TemplatedIteratorIiE, other: __CxxTemplateInst17TemplatedIteratorIiE) -> Bool
9494
// CHECK: }
95+
96+
// CHECK: struct __CxxTemplateInst28TemplatedIteratorOutOfLineEqIiE : UnsafeCxxInputIterator {
97+
// CHECK: var pointee: Int32 { get }
98+
// CHECK: func successor() -> __CxxTemplateInst28TemplatedIteratorOutOfLineEqIiE
99+
// CHECK: typealias Pointee = Int32
100+
// CHECK: }
101+
102+
// CHECK: struct __CxxTemplateInst31TemplatedRACIteratorOutOfLineEqIiE : UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
103+
// CHECK: var pointee: Int32 { get }
104+
// CHECK: func successor() -> __CxxTemplateInst31TemplatedRACIteratorOutOfLineEqIiE
105+
// CHECK: typealias Pointee = Int32
106+
// CHECK: typealias Distance = __CxxTemplateInst31TemplatedRACIteratorOutOfLineEqIiE.difference_type
107+
// CHECK: }

test/Interop/Cxx/stdlib/overlay/std-string-overlay.swift

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,6 @@ StdStringOverlayTestSuite.test("std::string <=> Swift.String") {
3737
expectEqual(swift6, "xyz\0abc")
3838
}
3939

40-
extension std.string.const_iterator: UnsafeCxxInputIterator {
41-
// This func should not be required.
42-
public static func ==(lhs: std.string.const_iterator,
43-
rhs: std.string.const_iterator) -> Bool {
44-
#if os(Linux)
45-
// In libstdc++, `base()` returns UnsafePointer<Optional<UnsafePointer<CChar>>>.
46-
return lhs.__baseUnsafe().pointee == rhs.__baseUnsafe().pointee
47-
#else
48-
// In libc++, `base()` returns UnsafePointer<CChar>.
49-
return lhs.__baseUnsafe() == rhs.__baseUnsafe()
50-
#endif
51-
}
52-
}
53-
54-
extension std.string: CxxSequence {}
55-
5640
StdStringOverlayTestSuite.test("std::string as Swift.Sequence") {
5741
let cxx1 = std.string()
5842
var iterated = false

test/Interop/Cxx/stdlib/use-std-map.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ StdMapTestSuite.test("subscript") {
2828
expectEqual(m[3], 3)
2929
}
3030

31-
extension Map.const_iterator : UnsafeCxxInputIterator { }
3231
extension Map : CxxSequence { }
3332

3433
StdMapTestSuite.test("first(where:)") {

test/Interop/Cxx/stdlib/use-std-set.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import Cxx
1212

1313
var StdSetTestSuite = TestSuite("StdSet")
1414

15-
extension SetOfCInt.const_iterator : UnsafeCxxInputIterator { }
1615
extension SetOfCInt : CxxSequence { }
1716

1817
StdSetTestSuite.test("iterate") {

test/Interop/Cxx/stdlib/use-std-vector.swift

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift(-I %S/Inputs -Xfrontend -enable-experimental-cxx-interop)
1+
// RUN: %target-run-simple-swift(-I %S/Inputs -Xfrontend -enable-experimental-cxx-interop -Xfrontend -validate-tbd-against-ir=none)
22
//
33
// REQUIRES: executable_test
44
//
@@ -11,11 +11,6 @@ import CxxStdlib.vector
1111

1212
var StdVectorTestSuite = TestSuite("StdVector")
1313

14-
extension Vector : RandomAccessCollection {
15-
public var startIndex: Int { 0 }
16-
public var endIndex: Int { size() }
17-
}
18-
1914
StdVectorTestSuite.test("init") {
2015
let v = Vector()
2116
expectEqual(v.size(), 0)

0 commit comments

Comments
 (0)