Skip to content

[cxx-interop] Add CxxMutableRandomAccessCollection protocol #76106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/AST/KnownProtocols.def
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ PROTOCOL(CxxOptional)
PROTOCOL(CxxPair)
PROTOCOL(CxxSet)
PROTOCOL(CxxRandomAccessCollection)
PROTOCOL(CxxMutableRandomAccessCollection)
PROTOCOL(CxxSequence)
PROTOCOL(CxxUniqueSet)
PROTOCOL(CxxVector)
Expand Down
1 change: 1 addition & 0 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1419,6 +1419,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
case KnownProtocolKind::CxxPair:
case KnownProtocolKind::CxxOptional:
case KnownProtocolKind::CxxRandomAccessCollection:
case KnownProtocolKind::CxxMutableRandomAccessCollection:
case KnownProtocolKind::CxxSet:
case KnownProtocolKind::CxxSequence:
case KnownProtocolKind::CxxUniqueSet:
Expand Down
40 changes: 38 additions & 2 deletions lib/ClangImporter/ClangDerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -780,8 +780,44 @@ void swift::conformToCxxSequenceIfNeeded(
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Indices"), indicesTy);
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("SubSequence"),
sliceTy);
impl.addSynthesizedProtocolAttrs(
decl, {KnownProtocolKind::CxxRandomAccessCollection});

auto tryToConformToMutatingRACollection = [&]() -> bool {
auto rawMutableIteratorProto = ctx.getProtocol(
KnownProtocolKind::UnsafeCxxMutableRandomAccessIterator);
if (!rawMutableIteratorProto)
return false;

// Check if present: `func __beginMutatingUnsafe() -> RawMutableIterator`
auto beginMutatingId = ctx.getIdentifier("__beginMutatingUnsafe");
auto beginMutating =
lookupDirectSingleWithoutExtensions<FuncDecl>(decl, beginMutatingId);
if (!beginMutating)
return false;
auto rawMutableIteratorTy = beginMutating->getResultInterfaceType();

// Check if present: `func __endMutatingUnsafe() -> RawMutableIterator`
auto endMutatingId = ctx.getIdentifier("__endMutatingUnsafe");
auto endMutating =
lookupDirectSingleWithoutExtensions<FuncDecl>(decl, endMutatingId);
if (!endMutating)
return false;

if (!checkConformance(rawMutableIteratorTy, rawMutableIteratorProto))
return false;

impl.addSynthesizedTypealias(
decl, ctx.getIdentifier("RawMutableIterator"), rawMutableIteratorTy);
impl.addSynthesizedProtocolAttrs(
decl, {KnownProtocolKind::CxxMutableRandomAccessCollection});
return true;
};

bool conformedToMutableRAC = tryToConformToMutatingRACollection();

if (!conformedToMutableRAC)
impl.addSynthesizedProtocolAttrs(
decl, {KnownProtocolKind::CxxRandomAccessCollection});

return true;
};

Expand Down
1 change: 1 addition & 0 deletions lib/IRGen/GenMeta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6883,6 +6883,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::CxxPair:
case KnownProtocolKind::CxxOptional:
case KnownProtocolKind::CxxRandomAccessCollection:
case KnownProtocolKind::CxxMutableRandomAccessCollection:
case KnownProtocolKind::CxxSet:
case KnownProtocolKind::CxxSequence:
case KnownProtocolKind::CxxUniqueSet:
Expand Down
46 changes: 41 additions & 5 deletions stdlib/public/Cxx/CxxRandomAccessCollection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,52 @@ extension CxxRandomAccessCollection {
return Int(__endUnsafe() - __beginUnsafe())
}

@inlinable
@inline(__always)
internal func _getRawIterator(at index: Int) -> RawIterator {
var rawIterator = self.__beginUnsafe()
rawIterator += RawIterator.Distance(index)
precondition(self.__endUnsafe() - rawIterator > 0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the plan for these when hardening for libcxx is on? Are we relying on the optimizer to remove the redundant checks or do we have a way to disable precondition checks when hardening is turned on?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@susmonteiro might know more details here, but the general idea is that there won't be a bounds-check here with hardening enabled, unless you also enable libc++ bounded iterators. This is because we're not using the C++ operator[] – this function would only be invoked if the C++ type does not have a suitable overload of operator[].

We don't currently have a good way to check if bounded iterators are enabled and provide an alternative implementation, so for now we're just hoping that the optimizer removes these checks, I think.

"C++ iterator access out of bounds")
return rawIterator
}

/// A C++ implementation of the subscript might be more performant. This
/// overload should only be used if the C++ type does not define `operator[]`.
@inlinable
public subscript(_ index: Int) -> Element {
_read {
// Not using CxxIterator here to avoid making a copy of the collection.
var rawIterator = __beginUnsafe()
rawIterator += RawIterator.Distance(index)
precondition(__endUnsafe() - rawIterator > 0, "C++ iterator access out of bounds")
yield rawIterator.pointee
yield self._getRawIterator(at: index).pointee
}
}
}

public protocol CxxMutableRandomAccessCollection<Element>:
CxxRandomAccessCollection, MutableCollection {
associatedtype RawMutableIterator: UnsafeCxxMutableRandomAccessIterator
where RawMutableIterator.Pointee == Element

/// Do not implement this function manually in Swift.
mutating func __beginMutatingUnsafe() -> RawMutableIterator

/// Do not implement this function manually in Swift.
mutating func __endMutatingUnsafe() -> RawMutableIterator
}

extension CxxMutableRandomAccessCollection {
/// A C++ implementation of the subscript might be more performant. This
/// overload should only be used if the C++ type does not define `operator[]`.
@inlinable
public subscript(_ index: Int) -> Element {
_read {
yield self._getRawIterator(at: index).pointee
}
_modify {
var rawIterator = self.__beginMutatingUnsafe()
rawIterator += RawMutableIterator.Distance(index)
precondition(self.__endMutatingUnsafe() - rawIterator > 0,
"C++ iterator access out of bounds")
yield &rawIterator.pointee
}
}
}
4 changes: 2 additions & 2 deletions test/Interop/Cxx/stdlib/libcxx-module-interface.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@

// CHECK-IOSFWD: enum std {
// CHECK-IOSFWD: enum __1 {
// CHECK-IOSFWD: struct basic_string<CChar, char_traits<CChar>, allocator<CChar>> : CxxRandomAccessCollection {
// CHECK-IOSFWD: struct basic_string<CChar, char_traits<CChar>, allocator<CChar>> : CxxMutableRandomAccessCollection {
// CHECK-IOSFWD: typealias value_type = CChar
// CHECK-IOSFWD: }
// CHECK-IOSFWD: struct basic_string<CWideChar, char_traits<CWideChar>, allocator<CWideChar>> : CxxRandomAccessCollection {
// CHECK-IOSFWD: struct basic_string<CWideChar, char_traits<CWideChar>, allocator<CWideChar>> : CxxMutableRandomAccessCollection {
// CHECK-IOSFWD: typealias value_type = CWideChar
// CHECK-IOSFWD: }
// CHECK-IOSFWD: typealias string = std.__1.basic_string<CChar, char_traits<CChar>, allocator<CChar>>
Expand Down
4 changes: 2 additions & 2 deletions test/Interop/Cxx/stdlib/libstdcxx-module-interface.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
// REQUIRES: OS=linux-gnu

// CHECK-STD: enum std {
// CHECK-STRING: struct basic_string<CChar, char_traits<CChar>, allocator<CChar>> : CxxRandomAccessCollection {
// CHECK-STRING: struct basic_string<CChar, char_traits<CChar>, allocator<CChar>> : CxxMutableRandomAccessCollection {
// CHECK-STRING: typealias value_type = std.char_traits<CChar>.char_type
// CHECK-STRING: }
// CHECK-STRING: struct basic_string<CWideChar, char_traits<CWideChar>, allocator<CWideChar>> : CxxRandomAccessCollection {
// CHECK-STRING: struct basic_string<CWideChar, char_traits<CWideChar>, allocator<CWideChar>> : CxxMutableRandomAccessCollection {
// CHECK-STRING: typealias value_type = std.char_traits<CWideChar>.char_type
// CHECK-STRING: }

Expand Down
25 changes: 21 additions & 4 deletions test/Interop/Cxx/stdlib/overlay/Inputs/custom-collection.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ struct SimpleCollectionNoSubscript {
public:
using iterator = ConstRACIterator;

iterator begin() const { return iterator(*x); }
iterator end() const { return iterator(*x + 5); }
iterator begin() const { return iterator(x); }
iterator end() const { return iterator(x + 5); }
};

struct SimpleCollectionReadOnly {
Expand All @@ -22,12 +22,29 @@ struct SimpleCollectionReadOnly {
public:
using iterator = ConstRACIteratorRefPlusEq;

iterator begin() const { return iterator(*x); }
iterator end() const { return iterator(*x + 5); }
iterator begin() const { return iterator(x); }
iterator end() const { return iterator(x + 5); }

const int& operator[](int index) const { return x[index]; }
};

struct SimpleCollectionReadWrite {
private:
int x[5] = {1, 2, 3, 4, 5};

public:
using const_iterator = ConstRACIterator;
using iterator = MutableRACIterator;

const_iterator begin() const { return const_iterator(x); }
const_iterator end() const { return const_iterator(x + 5); }
iterator begin() { return iterator(x); }
iterator end() { return iterator(x + 5); }

const int &operator[](int index) const { return x[index]; }
int &operator[](int index) { return x[index]; }
};

template <typename T>
struct HasInheritedTemplatedConstRACIterator {
public:
Expand Down
20 changes: 10 additions & 10 deletions test/Interop/Cxx/stdlib/overlay/Inputs/custom-iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct ConstIterator {

struct ConstRACIterator {
private:
int value;
const int *value;

public:
using iterator_category = std::random_access_iterator_tag;
Expand All @@ -51,10 +51,10 @@ struct ConstRACIterator {
using reference = const int &;
using difference_type = int;

ConstRACIterator(int value) : value(value) {}
ConstRACIterator(const int *value) : value(value) {}
ConstRACIterator(const ConstRACIterator &other) = default;

const int &operator*() const { return value; }
const int &operator*() const { return *value; }

ConstRACIterator &operator++() {
value++;
Expand Down Expand Up @@ -97,7 +97,7 @@ struct ConstRACIterator {
// Same as ConstRACIterator, but operator+= returns a reference to this.
struct ConstRACIteratorRefPlusEq {
private:
int value;
const int *value;

public:
using iterator_category = std::random_access_iterator_tag;
Expand All @@ -106,10 +106,10 @@ struct ConstRACIteratorRefPlusEq {
using reference = const int &;
using difference_type = int;

ConstRACIteratorRefPlusEq(int value) : value(value) {}
ConstRACIteratorRefPlusEq(const int *value) : value(value) {}
ConstRACIteratorRefPlusEq(const ConstRACIteratorRefPlusEq &other) = default;

const int &operator*() const { return value; }
const int &operator*() const { return *value; }

ConstRACIteratorRefPlusEq &operator++() {
value++;
Expand Down Expand Up @@ -918,7 +918,7 @@ struct InputOutputConstIterator {

struct MutableRACIterator {
private:
int value;
int *value;

public:
struct iterator_category : std::random_access_iterator_tag,
Expand All @@ -928,11 +928,11 @@ struct MutableRACIterator {
using reference = const int &;
using difference_type = int;

MutableRACIterator(int value) : value(value) {}
MutableRACIterator(int *value) : value(value) {}
MutableRACIterator(const MutableRACIterator &other) = default;

const int &operator*() const { return value; }
int &operator*() { return value; }
const int &operator*() const { return *value; }
int &operator*() { return *value; }

MutableRACIterator &operator++() {
value++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
// CHECK: typealias RawIterator = SimpleCollectionReadOnly.iterator
// CHECK: }

// CHECK: struct SimpleCollectionReadWrite : CxxMutableRandomAccessCollection {
// CHECK: typealias Element = ConstRACIterator.Pointee
// CHECK: typealias Iterator = CxxIterator<SimpleCollectionReadWrite>
// CHECK: typealias RawIterator = SimpleCollectionReadWrite.const_iterator
// CHECK: typealias RawMutableIterator = SimpleCollectionReadWrite.iterator
// CHECK: }

// CHECK: struct HasInheritedTemplatedConstRACIterator<CInt> : CxxRandomAccessCollection {
// CHECK: typealias Element = InheritedTemplatedConstRACIterator<CInt>.Pointee
// CHECK: typealias Iterator = CxxIterator<HasInheritedTemplatedConstRACIterator<CInt>>
Expand Down
17 changes: 17 additions & 0 deletions test/Interop/Cxx/stdlib/overlay/custom-collection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,23 @@ CxxCollectionTestSuite.test("SimpleCollectionReadOnly as Swift.Collection") {
expectEqual(slice.last, 3)
}

CxxCollectionTestSuite.test("SimpleCollectionReadWrite as Swift.MutableCollection") {
var c = SimpleCollectionReadWrite()
expectEqual(c.first, 1)
expectEqual(c.last, 5)

c.swapAt(0, 4)
expectEqual(c.first, 5)
expectEqual(c.last, 1)

c.reverse()
expectEqual(c[0], 1)
expectEqual(c[1], 4)
expectEqual(c[2], 3)
expectEqual(c[3], 2)
expectEqual(c[4], 5)
}

CxxCollectionTestSuite.test("SimpleArrayWrapper as Swift.Collection") {
let c = SimpleArrayWrapper()
expectEqual(c.first, 10)
Expand Down
30 changes: 30 additions & 0 deletions test/Interop/Cxx/stdlib/use-std-vector.swift
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,36 @@ StdVectorTestSuite.test("VectorOfInt as ExpressibleByArrayLiteral") {
expectEqual(v2[2], 3)
}

#if !os(Windows) // FIXME: rdar://113704853
StdVectorTestSuite.test("VectorOfInt as MutableCollection") {
var v = Vector([2, 3, 1])
v.sort() // Swift function
expectEqual(v[0], 1)
expectEqual(v[1], 2)
expectEqual(v[2], 3)

v.reverse() // Swift function
expectEqual(v[0], 3)
expectEqual(v[1], 2)
expectEqual(v[2], 1)
}

StdVectorTestSuite.test("VectorOfString as MutableCollection") {
var v = VectorOfString([std.string("xyz"),
std.string("abc"),
std.string("ijk")])
v.swapAt(0, 2) // Swift function
expectEqual(v[0], std.string("ijk"))
expectEqual(v[1], std.string("abc"))
expectEqual(v[2], std.string("xyz"))

v.reverse() // Swift function
expectEqual(v[0], std.string("xyz"))
expectEqual(v[1], std.string("abc"))
expectEqual(v[2], std.string("ijk"))
}
#endif

StdVectorTestSuite.test("VectorOfInt.push_back") {
var v = Vector()
let _42: CInt = 42
Expand Down