Skip to content

[cxx-interop] Add UnsafeCxxMutableInputIterator protocol #67536

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/AST/KnownProtocols.def
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ PROTOCOL(CxxRandomAccessCollection)
PROTOCOL(CxxSequence)
PROTOCOL(CxxUniqueSet)
PROTOCOL(UnsafeCxxInputIterator)
PROTOCOL(UnsafeCxxMutableInputIterator)
PROTOCOL(UnsafeCxxRandomAccessIterator)

PROTOCOL(AsyncSequence)
Expand Down
1 change: 1 addition & 0 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
case KnownProtocolKind::CxxSequence:
case KnownProtocolKind::CxxUniqueSet:
case KnownProtocolKind::UnsafeCxxInputIterator:
case KnownProtocolKind::UnsafeCxxMutableInputIterator:
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
M = getLoadedModule(Id_Cxx);
break;
Expand Down
14 changes: 12 additions & 2 deletions lib/ClangImporter/ClangDerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,11 @@ void swift::conformToCxxIteratorIfNeeded(
if (!pointee || pointee->isGetterMutating() || pointee->getType()->hasError())
return;

// Check if `var pointee: Pointee` is settable. This is required for the
// conformance to UnsafeCxxMutableInputIterator but is not necessary for
// UnsafeCxxInputIterator.
bool pointeeSettable = pointee->isSettable(nullptr);

// Check if present: `func successor() -> Self`
auto successorId = ctx.getIdentifier("successor");
auto successor =
Expand Down Expand Up @@ -469,8 +474,13 @@ void swift::conformToCxxIteratorIfNeeded(

impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Pointee"),
pointee->getType());
impl.addSynthesizedProtocolAttrs(decl,
{KnownProtocolKind::UnsafeCxxInputIterator});
if (pointeeSettable)
impl.addSynthesizedProtocolAttrs(
decl, {KnownProtocolKind::UnsafeCxxMutableInputIterator});
else
impl.addSynthesizedProtocolAttrs(
decl, {KnownProtocolKind::UnsafeCxxInputIterator});

if (!isRandomAccessIterator ||
!ctx.getProtocol(KnownProtocolKind::UnsafeCxxRandomAccessIterator))
return;
Expand Down
1 change: 1 addition & 0 deletions lib/IRGen/GenMeta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6311,6 +6311,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::CxxSequence:
case KnownProtocolKind::CxxUniqueSet:
case KnownProtocolKind::UnsafeCxxInputIterator:
case KnownProtocolKind::UnsafeCxxMutableInputIterator:
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
case KnownProtocolKind::Executor:
case KnownProtocolKind::SerialExecutor:
Expand Down
4 changes: 4 additions & 0 deletions stdlib/public/Cxx/UnsafeCxxIterators.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ extension Optional: UnsafeCxxInputIterator where Wrapped: UnsafeCxxInputIterator
}
}

public protocol UnsafeCxxMutableInputIterator: UnsafeCxxInputIterator {
override var pointee: Pointee { get set }
}

/// Bridged C++ iterator that allows computing the distance between two of its
/// instances, and advancing an instance by a given number of elements.
///
Expand Down
56 changes: 56 additions & 0 deletions test/Interop/Cxx/stdlib/overlay/Inputs/custom-iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -866,4 +866,60 @@ struct InputOutputIterator {
}
};

struct MutableRACIterator {
private:
int value;

public:
struct iterator_category : std::random_access_iterator_tag,
std::output_iterator_tag {};
using value_type = int;
using pointer = int *;
using reference = const int &;
using difference_type = int;

MutableRACIterator(int value) : value(value) {}
MutableRACIterator(const MutableRACIterator &other) = default;

const int &operator*() const { return value; }
int &operator*() { return value; }

MutableRACIterator &operator++() {
value++;
return *this;
}
MutableRACIterator operator++(int) {
auto tmp = MutableRACIterator(value);
value++;
return tmp;
}

void operator+=(difference_type v) { value += v; }
void operator-=(difference_type v) { value -= v; }
MutableRACIterator operator+(difference_type v) const {
return MutableRACIterator(value + v);
}
MutableRACIterator operator-(difference_type v) const {
return MutableRACIterator(value - v);
}
friend MutableRACIterator operator+(difference_type v,
const MutableRACIterator &it) {
return it + v;
}
int operator-(const MutableRACIterator &other) const {
return value - other.value;
}

bool operator<(const MutableRACIterator &other) const {
return value < other.value;
}

bool operator==(const MutableRACIterator &other) const {
return value == other.value;
}
bool operator!=(const MutableRACIterator &other) const {
return value != other.value;
}
};

#endif // TEST_INTEROP_CXX_STDLIB_INPUTS_CUSTOM_ITERATOR_H
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,15 @@
// CHECK: struct InheritedTemplatedConstRACIteratorOutOfLineOps<Int32> : UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
// CHECK: }

// CHECK: struct InputOutputIterator : UnsafeCxxInputIterator {
// CHECK: struct InputOutputIterator : UnsafeCxxMutableInputIterator {
// CHECK: func successor() -> InputOutputIterator
// CHECK: var pointee: Int32
// CHECK: typealias Pointee = Int32
// CHECK: }

// CHECK: struct MutableRACIterator : UnsafeCxxRandomAccessIterator, UnsafeCxxMutableInputIterator {
// CHECK: func successor() -> MutableRACIterator
// CHECK: var pointee: Int32
// CHECK: typealias Pointee = Int32
// CHECK: typealias Distance = Int32
// CHECK: }