Skip to content

Commit cbd533d

Browse files
committed
[cxx-interop] Add UnsafeCxxContiguousIterator & UnsafeCxxMutableContiguousIterator protocols
This adds a pair of Swift protocols that represents C++ iterator types conforming to `std::contiguous_iterator_tag` requirements. These are random access iterators that guarantee that the values are stored in consequent memory addresses. This will be used to optimize usage of C++ containers such as `std::vector` from Swift, for instance, by providing an overload of `withContiguousStorageIfAvailable` for contiguous containers.
1 parent a9d5903 commit cbd533d

File tree

7 files changed

+241
-0
lines changed

7 files changed

+241
-0
lines changed

include/swift/AST/KnownProtocols.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ PROTOCOL(UnsafeCxxInputIterator)
142142
PROTOCOL(UnsafeCxxMutableInputIterator)
143143
PROTOCOL(UnsafeCxxRandomAccessIterator)
144144
PROTOCOL(UnsafeCxxMutableRandomAccessIterator)
145+
PROTOCOL(UnsafeCxxContiguousIterator)
146+
PROTOCOL(UnsafeCxxMutableContiguousIterator)
145147

146148
PROTOCOL(AsyncSequence)
147149
PROTOCOL(AsyncIteratorProtocol)

lib/AST/ASTContext.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,6 +1444,8 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
14441444
case KnownProtocolKind::UnsafeCxxMutableInputIterator:
14451445
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
14461446
case KnownProtocolKind::UnsafeCxxMutableRandomAccessIterator:
1447+
case KnownProtocolKind::UnsafeCxxContiguousIterator:
1448+
case KnownProtocolKind::UnsafeCxxMutableContiguousIterator:
14471449
M = getLoadedModule(Id_Cxx);
14481450
break;
14491451
case KnownProtocolKind::Copyable:

lib/ClangImporter/ClangDerivedConformances.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,19 +462,28 @@ void swift::conformToCxxIteratorIfNeeded(
462462
auto isRandomAccessIteratorDecl = [&](const clang::CXXRecordDecl *base) {
463463
return isIteratorCategoryDecl(base, "random_access_iterator_tag");
464464
};
465+
auto isContiguousIteratorDecl = [&](const clang::CXXRecordDecl *base) {
466+
return isIteratorCategoryDecl(base, "contiguous_iterator_tag"); // C++20
467+
};
465468

466469
// Traverse all transitive bases of `underlyingDecl` to check if
467470
// it inherits from `std::input_iterator_tag`.
468471
bool isInputIterator = isInputIteratorDecl(underlyingCategoryDecl);
469472
bool isRandomAccessIterator =
470473
isRandomAccessIteratorDecl(underlyingCategoryDecl);
474+
bool isContiguousIterator = isContiguousIteratorDecl(underlyingCategoryDecl);
471475
underlyingCategoryDecl->forallBases([&](const clang::CXXRecordDecl *base) {
472476
if (isInputIteratorDecl(base)) {
473477
isInputIterator = true;
474478
}
475479
if (isRandomAccessIteratorDecl(base)) {
476480
isRandomAccessIterator = true;
477481
isInputIterator = true;
482+
}
483+
if (isContiguousIteratorDecl(base)) {
484+
isContiguousIterator = true;
485+
isRandomAccessIterator = true;
486+
isInputIterator = true;
478487
return false;
479488
}
480489
return true;
@@ -594,6 +603,15 @@ void swift::conformToCxxIteratorIfNeeded(
594603
else
595604
impl.addSynthesizedProtocolAttrs(
596605
decl, {KnownProtocolKind::UnsafeCxxRandomAccessIterator});
606+
607+
if (isContiguousIterator) {
608+
if (pointeeSettable)
609+
impl.addSynthesizedProtocolAttrs(
610+
decl, {KnownProtocolKind::UnsafeCxxMutableContiguousIterator});
611+
else
612+
impl.addSynthesizedProtocolAttrs(
613+
decl, {KnownProtocolKind::UnsafeCxxContiguousIterator});
614+
}
597615
}
598616

599617
void swift::conformToCxxConvertibleToBoolIfNeeded(

lib/IRGen/GenMeta.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6967,6 +6967,8 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
69676967
case KnownProtocolKind::UnsafeCxxMutableInputIterator:
69686968
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
69696969
case KnownProtocolKind::UnsafeCxxMutableRandomAccessIterator:
6970+
case KnownProtocolKind::UnsafeCxxContiguousIterator:
6971+
case KnownProtocolKind::UnsafeCxxMutableContiguousIterator:
69706972
case KnownProtocolKind::Executor:
69716973
case KnownProtocolKind::SerialExecutor:
69726974
case KnownProtocolKind::TaskExecutor:

stdlib/public/Cxx/UnsafeCxxIterators.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,8 @@ public protocol UnsafeCxxMutableRandomAccessIterator:
8787
UnsafeCxxRandomAccessIterator, UnsafeCxxMutableInputIterator {}
8888

8989
extension UnsafeMutablePointer: UnsafeCxxMutableRandomAccessIterator {}
90+
91+
public protocol UnsafeCxxContiguousIterator: UnsafeCxxRandomAccessIterator {}
92+
93+
public protocol UnsafeCxxMutableContiguousIterator:
94+
UnsafeCxxContiguousIterator, UnsafeCxxMutableRandomAccessIterator {}

test/Interop/Cxx/stdlib/overlay/Inputs/custom-iterator.h

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,181 @@ struct HasTypedefIteratorTag {
286286
}
287287
};
288288

289+
#if __cplusplus >= 202002L
290+
struct ConstContiguousIterator {
291+
private:
292+
const int *value;
293+
294+
public:
295+
using iterator_category = std::contiguous_iterator_tag;
296+
using value_type = int;
297+
using pointer = int *;
298+
using reference = const int &;
299+
using difference_type = int;
300+
301+
ConstContiguousIterator(const int *value) : value(value) {}
302+
ConstContiguousIterator(const ConstContiguousIterator &other) = default;
303+
304+
const int &operator*() const { return *value; }
305+
306+
ConstContiguousIterator &operator++() {
307+
value++;
308+
return *this;
309+
}
310+
ConstContiguousIterator operator++(int) {
311+
auto tmp = ConstContiguousIterator(value);
312+
value++;
313+
return tmp;
314+
}
315+
316+
void operator+=(difference_type v) { value += v; }
317+
void operator-=(difference_type v) { value -= v; }
318+
ConstContiguousIterator operator+(difference_type v) const {
319+
return ConstContiguousIterator(value + v);
320+
}
321+
ConstContiguousIterator operator-(difference_type v) const {
322+
return ConstContiguousIterator(value - v);
323+
}
324+
friend ConstContiguousIterator operator+(difference_type v,
325+
const ConstContiguousIterator &it) {
326+
return it + v;
327+
}
328+
int operator-(const ConstContiguousIterator &other) const {
329+
return value - other.value;
330+
}
331+
332+
bool operator<(const ConstContiguousIterator &other) const {
333+
return value < other.value;
334+
}
335+
336+
bool operator==(const ConstContiguousIterator &other) const {
337+
return value == other.value;
338+
}
339+
bool operator!=(const ConstContiguousIterator &other) const {
340+
return value != other.value;
341+
}
342+
};
343+
344+
struct HasCustomContiguousIteratorTag {
345+
private:
346+
const int *value;
347+
348+
public:
349+
struct CustomTag : std::contiguous_iterator_tag {};
350+
using iterator_category = CustomTag;
351+
using value_type = int;
352+
using pointer = int *;
353+
using reference = const int &;
354+
using difference_type = int;
355+
356+
HasCustomContiguousIteratorTag(const int *value) : value(value) {}
357+
HasCustomContiguousIteratorTag(const HasCustomContiguousIteratorTag &other) =
358+
default;
359+
360+
const int &operator*() const { return *value; }
361+
362+
HasCustomContiguousIteratorTag &operator++() {
363+
value++;
364+
return *this;
365+
}
366+
HasCustomContiguousIteratorTag operator++(int) {
367+
auto tmp = HasCustomContiguousIteratorTag(value);
368+
value++;
369+
return tmp;
370+
}
371+
372+
void operator+=(difference_type v) { value += v; }
373+
void operator-=(difference_type v) { value -= v; }
374+
HasCustomContiguousIteratorTag operator+(difference_type v) const {
375+
return HasCustomContiguousIteratorTag(value + v);
376+
}
377+
HasCustomContiguousIteratorTag operator-(difference_type v) const {
378+
return HasCustomContiguousIteratorTag(value - v);
379+
}
380+
friend HasCustomContiguousIteratorTag
381+
operator+(difference_type v, const HasCustomContiguousIteratorTag &it) {
382+
return it + v;
383+
}
384+
int operator-(const HasCustomContiguousIteratorTag &other) const {
385+
return value - other.value;
386+
}
387+
388+
bool operator<(const HasCustomContiguousIteratorTag &other) const {
389+
return value < other.value;
390+
}
391+
392+
bool operator==(const HasCustomContiguousIteratorTag &other) const {
393+
return value == other.value;
394+
}
395+
bool operator!=(const HasCustomContiguousIteratorTag &other) const {
396+
return value != other.value;
397+
}
398+
};
399+
400+
struct ConstContiguousIteratorInheritedFromConstRACIterator : ConstRACIterator {
401+
using iterator_category = std::contiguous_iterator_tag;
402+
};
403+
404+
struct MutableContiguousIterator {
405+
private:
406+
int *value;
407+
408+
public:
409+
using iterator_category = std::contiguous_iterator_tag;
410+
using value_type = int;
411+
using pointer = int *;
412+
using reference = const int &;
413+
using difference_type = int;
414+
415+
MutableContiguousIterator(int *value) : value(value) {}
416+
MutableContiguousIterator(const MutableRACIterator &other) = default;
417+
418+
const int &operator*() const { return *value; }
419+
int &operator*() { return *value; }
420+
421+
MutableContiguousIterator &operator++() {
422+
value++;
423+
return *this;
424+
}
425+
MutableContiguousIterator operator++(int) {
426+
auto tmp = MutableRACIterator(value);
427+
value++;
428+
return tmp;
429+
}
430+
431+
void operator+=(difference_type v) { value += v; }
432+
void operator-=(difference_type v) { value -= v; }
433+
MutableContiguousIterator operator+(difference_type v) const {
434+
return MutableContiguousIterator(value + v);
435+
}
436+
MutableContiguousIterator operator-(difference_type v) const {
437+
return MutableContiguousIterator(value - v);
438+
}
439+
friend MutableContiguousIterator
440+
operator+(difference_type v, const MutableContiguousIterator &it) {
441+
return it + v;
442+
}
443+
int operator-(const MutableContiguousIterator &other) const {
444+
return value - other.value;
445+
}
446+
447+
bool operator<(const MutableContiguousIterator &other) const {
448+
return value < other.value;
449+
}
450+
451+
bool operator==(const MutableContiguousIterator &other) const {
452+
return value == other.value;
453+
}
454+
bool operator!=(const MutableContiguousIterator &other) const {
455+
return value != other.value;
456+
}
457+
};
458+
459+
struct MutableContiguousInheritedFromMutableRACIterator : MutableRACIterator {
460+
using iterator_category = std::contiguous_iterator_tag;
461+
};
462+
#endif
463+
289464
// MARK: Types that are not actually iterators
290465

291466
struct HasNoIteratorCategory {

test/Interop/Cxx/stdlib/overlay/custom-iterator-module-interface.swift

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// RUN: %target-swift-ide-test -print-module -module-to-print=CustomIterator -source-filename=x -I %S/Inputs -enable-experimental-cxx-interop | %FileCheck %s
22
// RUN: %target-swift-ide-test -print-module -module-to-print=CustomIterator -source-filename=x -I %S/Inputs -cxx-interoperability-mode=swift-6 | %FileCheck %s
3+
// RUN: %target-swift-ide-test -print-module -module-to-print=CustomIterator -source-filename=x -I %S/Inputs -cxx-interoperability-mode=swift-6 -Xcc -std=c++20 | %FileCheck %s --check-prefix=CHECK-CXX20
34
// RUN: %target-swift-ide-test -print-module -module-to-print=CustomIterator -source-filename=x -I %S/Inputs -cxx-interoperability-mode=upcoming-swift | %FileCheck %s
5+
// RUN: %target-swift-ide-test -print-module -module-to-print=CustomIterator -source-filename=x -I %S/Inputs -cxx-interoperability-mode=upcoming-swift -Xcc -std=c++20 | %FileCheck %s --check-prefix=CHECK-CXX20
46

57
// CHECK: struct ConstIterator : UnsafeCxxInputIterator {
68
// CHECK: func successor() -> ConstIterator
@@ -80,6 +82,41 @@
8082
// CHECK: static func == (lhs: HasTypedefIteratorTag, other: HasTypedefIteratorTag) -> Bool
8183
// CHECK: }
8284

85+
// CHECK-CXX20: struct ConstContiguousIterator : UnsafeCxxContiguousIterator, UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
86+
// CHECK-CXX20: func successor() -> ConstContiguousIterator
87+
// CHECK-CXX20: var pointee: Int32
88+
// CHECK-CXX20: typealias Pointee = Int32
89+
// CHECK-CXX20: typealias Distance = Int32
90+
// CHECK-CXX20: }
91+
92+
// CHECK-CXX20: struct HasCustomContiguousIteratorTag : UnsafeCxxContiguousIterator, UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
93+
// CHECK-CXX20: func successor() -> HasCustomContiguousIteratorTag
94+
// CHECK-CXX20: var pointee: Int32
95+
// CHECK-CXX20: typealias Pointee = Int32
96+
// CHECK-CXX20: typealias Distance = Int32
97+
// CHECK-CXX20: }
98+
99+
// CHECK-CXX20: struct ConstContiguousIteratorInheritedFromConstRACIterator : UnsafeCxxContiguousIterator, UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
100+
// CHECK-CXX20: func successor() -> ConstContiguousIteratorInheritedFromConstRACIterator
101+
// CHECK-CXX20: var pointee: Int32
102+
// CHECK-CXX20: typealias Pointee = Int32
103+
// CHECK-CXX20: typealias Distance = Int32
104+
// CHECK-CXX20: }
105+
106+
// CHECK-CXX20: struct MutableContiguousIterator : UnsafeCxxMutableContiguousIterator, UnsafeCxxMutableRandomAccessIterator, UnsafeCxxMutableInputIterator {
107+
// CHECK-CXX20: func successor() -> MutableContiguousIterator
108+
// CHECK-CXX20: var pointee: Int32
109+
// CHECK-CXX20: typealias Pointee = Int32
110+
// CHECK-CXX20: typealias Distance = Int32
111+
// CHECK-CXX20: }
112+
113+
// CHECK-CXX20: struct MutableContiguousInheritedFromMutableRACIterator : UnsafeCxxMutableContiguousIterator, UnsafeCxxMutableRandomAccessIterator, UnsafeCxxMutableInputIterator {
114+
// CHECK-CXX20: func successor() -> MutableContiguousInheritedFromMutableRACIterator
115+
// CHECK-CXX20: var pointee: Int32
116+
// CHECK-CXX20: typealias Pointee = Int32
117+
// CHECK-CXX20: typealias Distance = Int32
118+
// CHECK-CXX20: }
119+
83120
// CHECK-NOT: struct HasNoIteratorCategory : UnsafeCxxInputIterator
84121
// CHECK-NOT: struct HasInvalidIteratorCategory : UnsafeCxxInputIterator
85122
// CHECK-NOT: struct HasNoEqualEqual : UnsafeCxxInputIterator

0 commit comments

Comments
 (0)