Skip to content

[cxx-interop] Add UnsafeCxxContiguousIterator & UnsafeCxxMutableContiguousIterator protocols #77006

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/swift/AST/KnownProtocols.def
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ PROTOCOL(UnsafeCxxInputIterator)
PROTOCOL(UnsafeCxxMutableInputIterator)
PROTOCOL(UnsafeCxxRandomAccessIterator)
PROTOCOL(UnsafeCxxMutableRandomAccessIterator)
PROTOCOL(UnsafeCxxContiguousIterator)
PROTOCOL(UnsafeCxxMutableContiguousIterator)

PROTOCOL(AsyncSequence)
PROTOCOL(AsyncIteratorProtocol)
Expand Down
2 changes: 2 additions & 0 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1444,6 +1444,8 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
case KnownProtocolKind::UnsafeCxxMutableInputIterator:
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
case KnownProtocolKind::UnsafeCxxMutableRandomAccessIterator:
case KnownProtocolKind::UnsafeCxxContiguousIterator:
case KnownProtocolKind::UnsafeCxxMutableContiguousIterator:
M = getLoadedModule(Id_Cxx);
break;
case KnownProtocolKind::Copyable:
Expand Down
18 changes: 18 additions & 0 deletions lib/ClangImporter/ClangDerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,19 +462,28 @@ void swift::conformToCxxIteratorIfNeeded(
auto isRandomAccessIteratorDecl = [&](const clang::CXXRecordDecl *base) {
return isIteratorCategoryDecl(base, "random_access_iterator_tag");
};
auto isContiguousIteratorDecl = [&](const clang::CXXRecordDecl *base) {
return isIteratorCategoryDecl(base, "contiguous_iterator_tag"); // C++20
};

// Traverse all transitive bases of `underlyingDecl` to check if
// it inherits from `std::input_iterator_tag`.
bool isInputIterator = isInputIteratorDecl(underlyingCategoryDecl);
bool isRandomAccessIterator =
isRandomAccessIteratorDecl(underlyingCategoryDecl);
bool isContiguousIterator = isContiguousIteratorDecl(underlyingCategoryDecl);
underlyingCategoryDecl->forallBases([&](const clang::CXXRecordDecl *base) {
if (isInputIteratorDecl(base)) {
isInputIterator = true;
}
if (isRandomAccessIteratorDecl(base)) {
isRandomAccessIterator = true;
isInputIterator = true;
}
if (isContiguousIteratorDecl(base)) {
isContiguousIterator = true;
isRandomAccessIterator = true;
isInputIterator = true;
return false;
}
return true;
Expand Down Expand Up @@ -594,6 +603,15 @@ void swift::conformToCxxIteratorIfNeeded(
else
impl.addSynthesizedProtocolAttrs(
decl, {KnownProtocolKind::UnsafeCxxRandomAccessIterator});

if (isContiguousIterator) {
if (pointeeSettable)
impl.addSynthesizedProtocolAttrs(
decl, {KnownProtocolKind::UnsafeCxxMutableContiguousIterator});
else
impl.addSynthesizedProtocolAttrs(
decl, {KnownProtocolKind::UnsafeCxxContiguousIterator});
}
}

void swift::conformToCxxConvertibleToBoolIfNeeded(
Expand Down
2 changes: 2 additions & 0 deletions lib/IRGen/GenMeta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6967,6 +6967,8 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::UnsafeCxxMutableInputIterator:
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
case KnownProtocolKind::UnsafeCxxMutableRandomAccessIterator:
case KnownProtocolKind::UnsafeCxxContiguousIterator:
case KnownProtocolKind::UnsafeCxxMutableContiguousIterator:
case KnownProtocolKind::Executor:
case KnownProtocolKind::SerialExecutor:
case KnownProtocolKind::TaskExecutor:
Expand Down
12 changes: 12 additions & 0 deletions stdlib/public/Cxx/UnsafeCxxIterators.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,15 @@ public protocol UnsafeCxxMutableRandomAccessIterator:
UnsafeCxxRandomAccessIterator, UnsafeCxxMutableInputIterator {}

extension UnsafeMutablePointer: UnsafeCxxMutableRandomAccessIterator {}

/// Bridged C++ iterator that allows traversing elements of a random access
/// collection that are stored in contiguous memory segments.
///
/// Mostly useful for optimizing operations with containers that conform to
/// `CxxRandomAccessCollection` and should not generally be used directly.
///
/// - SeeAlso: https://en.cppreference.com/w/cpp/named_req/ContiguousIterator
public protocol UnsafeCxxContiguousIterator: UnsafeCxxRandomAccessIterator {}

public protocol UnsafeCxxMutableContiguousIterator:
UnsafeCxxContiguousIterator, UnsafeCxxMutableRandomAccessIterator {}
279 changes: 223 additions & 56 deletions test/Interop/Cxx/stdlib/overlay/Inputs/custom-iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,229 @@ struct HasTypedefIteratorTag {
}
};

struct MutableRACIterator {
private:
int *value;

public:
struct iterator_category : std::random_access_iterator_tag,
std::output_iterator_tag {};
using value_type = int;
using pointer = int *;
using reference = const int &;
using difference_type = int;

MutableRACIterator(int *value) : value(value) {}
MutableRACIterator(const MutableRACIterator &other) = default;

const int &operator*() const { return *value; }
int &operator*() { return *value; }

MutableRACIterator &operator++() {
value++;
return *this;
}
MutableRACIterator operator++(int) {
auto tmp = MutableRACIterator(value);
value++;
return tmp;
}

void operator+=(difference_type v) { value += v; }
void operator-=(difference_type v) { value -= v; }
MutableRACIterator operator+(difference_type v) const {
return MutableRACIterator(value + v);
}
MutableRACIterator operator-(difference_type v) const {
return MutableRACIterator(value - v);
}
friend MutableRACIterator operator+(difference_type v,
const MutableRACIterator &it) {
return it + v;
}
int operator-(const MutableRACIterator &other) const {
return value - other.value;
}

bool operator<(const MutableRACIterator &other) const {
return value < other.value;
}

bool operator==(const MutableRACIterator &other) const {
return value == other.value;
}
bool operator!=(const MutableRACIterator &other) const {
return value != other.value;
}
};

#if __cplusplus >= 202002L
struct ConstContiguousIterator {
private:
const int *value;

public:
using iterator_category = std::contiguous_iterator_tag;
using value_type = int;
using pointer = int *;
using reference = const int &;
using difference_type = int;

ConstContiguousIterator(const int *value) : value(value) {}
ConstContiguousIterator(const ConstContiguousIterator &other) = default;

const int &operator*() const { return *value; }

ConstContiguousIterator &operator++() {
value++;
return *this;
}
ConstContiguousIterator operator++(int) {
auto tmp = ConstContiguousIterator(value);
value++;
return tmp;
}

void operator+=(difference_type v) { value += v; }
void operator-=(difference_type v) { value -= v; }
ConstContiguousIterator operator+(difference_type v) const {
return ConstContiguousIterator(value + v);
}
ConstContiguousIterator operator-(difference_type v) const {
return ConstContiguousIterator(value - v);
}
friend ConstContiguousIterator operator+(difference_type v,
const ConstContiguousIterator &it) {
return it + v;
}
int operator-(const ConstContiguousIterator &other) const {
return value - other.value;
}

bool operator<(const ConstContiguousIterator &other) const {
return value < other.value;
}

bool operator==(const ConstContiguousIterator &other) const {
return value == other.value;
}
bool operator!=(const ConstContiguousIterator &other) const {
return value != other.value;
}
};

struct HasCustomContiguousIteratorTag {
private:
const int *value;

public:
struct CustomTag : std::contiguous_iterator_tag {};
using iterator_category = CustomTag;
using value_type = int;
using pointer = int *;
using reference = const int &;
using difference_type = int;

HasCustomContiguousIteratorTag(const int *value) : value(value) {}
HasCustomContiguousIteratorTag(const HasCustomContiguousIteratorTag &other) =
default;

const int &operator*() const { return *value; }

HasCustomContiguousIteratorTag &operator++() {
value++;
return *this;
}
HasCustomContiguousIteratorTag operator++(int) {
auto tmp = HasCustomContiguousIteratorTag(value);
value++;
return tmp;
}

void operator+=(difference_type v) { value += v; }
void operator-=(difference_type v) { value -= v; }
HasCustomContiguousIteratorTag operator+(difference_type v) const {
return HasCustomContiguousIteratorTag(value + v);
}
HasCustomContiguousIteratorTag operator-(difference_type v) const {
return HasCustomContiguousIteratorTag(value - v);
}
friend HasCustomContiguousIteratorTag
operator+(difference_type v, const HasCustomContiguousIteratorTag &it) {
return it + v;
}
int operator-(const HasCustomContiguousIteratorTag &other) const {
return value - other.value;
}

bool operator<(const HasCustomContiguousIteratorTag &other) const {
return value < other.value;
}

bool operator==(const HasCustomContiguousIteratorTag &other) const {
return value == other.value;
}
bool operator!=(const HasCustomContiguousIteratorTag &other) const {
return value != other.value;
}
};

struct MutableContiguousIterator {
private:
int *value;

public:
using iterator_category = std::contiguous_iterator_tag;
using value_type = int;
using pointer = int *;
using reference = const int &;
using difference_type = int;

MutableContiguousIterator(int *value) : value(value) {}
MutableContiguousIterator(const MutableContiguousIterator &other) = default;

const int &operator*() const { return *value; }
int &operator*() { return *value; }

MutableContiguousIterator &operator++() {
value++;
return *this;
}
MutableContiguousIterator operator++(int) {
auto tmp = MutableContiguousIterator(value);
value++;
return tmp;
}

void operator+=(difference_type v) { value += v; }
void operator-=(difference_type v) { value -= v; }
MutableContiguousIterator operator+(difference_type v) const {
return MutableContiguousIterator(value + v);
}
MutableContiguousIterator operator-(difference_type v) const {
return MutableContiguousIterator(value - v);
}
friend MutableContiguousIterator
operator+(difference_type v, const MutableContiguousIterator &it) {
return it + v;
}
int operator-(const MutableContiguousIterator &other) const {
return value - other.value;
}

bool operator<(const MutableContiguousIterator &other) const {
return value < other.value;
}

bool operator==(const MutableContiguousIterator &other) const {
return value == other.value;
}
bool operator!=(const MutableContiguousIterator &other) const {
return value != other.value;
}
};
#endif

// MARK: Types that are not actually iterators

struct HasNoIteratorCategory {
Expand Down Expand Up @@ -916,62 +1139,6 @@ struct InputOutputConstIterator {
}
};

struct MutableRACIterator {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(had to move this elsewhere in the file with no source changes)

private:
int *value;

public:
struct iterator_category : std::random_access_iterator_tag,
std::output_iterator_tag {};
using value_type = int;
using pointer = int *;
using reference = const int &;
using difference_type = int;

MutableRACIterator(int *value) : value(value) {}
MutableRACIterator(const MutableRACIterator &other) = default;

const int &operator*() const { return *value; }
int &operator*() { return *value; }

MutableRACIterator &operator++() {
value++;
return *this;
}
MutableRACIterator operator++(int) {
auto tmp = MutableRACIterator(value);
value++;
return tmp;
}

void operator+=(difference_type v) { value += v; }
void operator-=(difference_type v) { value -= v; }
MutableRACIterator operator+(difference_type v) const {
return MutableRACIterator(value + v);
}
MutableRACIterator operator-(difference_type v) const {
return MutableRACIterator(value - v);
}
friend MutableRACIterator operator+(difference_type v,
const MutableRACIterator &it) {
return it + v;
}
int operator-(const MutableRACIterator &other) const {
return value - other.value;
}

bool operator<(const MutableRACIterator &other) const {
return value < other.value;
}

bool operator==(const MutableRACIterator &other) const {
return value == other.value;
}
bool operator!=(const MutableRACIterator &other) const {
return value != other.value;
}
};

/// clang::StmtIteratorBase
class ProtectedIteratorBase {
protected:
Expand Down
Loading