Skip to content

Commit f88b29b

Browse files
authored
Merge pull request #77445 from CrazyFanFan/feature_cxxset_remove
[cxx-interop] Allow removing elements from `std::set`.
2 parents f802b67 + aeaa8ec commit f88b29b

File tree

3 files changed

+55
-5
lines changed

3 files changed

+55
-5
lines changed

lib/ClangImporter/ClangDerivedConformances.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -928,22 +928,28 @@ void swift::conformToCxxSetIfNeeded(ClangImporter::Implementation &impl,
928928
if (!isStdDecl(clangDecl, {"set", "unordered_set"}))
929929
return;
930930

931-
ProtocolDecl *cxxIteratorProto =
931+
ProtocolDecl *cxxInputIteratorProto =
932932
ctx.getProtocol(KnownProtocolKind::UnsafeCxxInputIterator);
933-
if (!cxxIteratorProto)
933+
if (!cxxInputIteratorProto)
934934
return;
935935

936+
auto rawIteratorType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
937+
decl, ctx.getIdentifier("const_iterator"));
936938
auto rawMutableIteratorType =
937939
lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
938940
decl, ctx.getIdentifier("iterator"));
939-
if (!rawMutableIteratorType)
941+
if (!rawIteratorType || !rawMutableIteratorType)
940942
return;
941943

944+
auto rawIteratorTy = rawIteratorType->getUnderlyingType();
942945
auto rawMutableIteratorTy = rawMutableIteratorType->getUnderlyingType();
943-
// Check if RawMutableIterator conforms to UnsafeCxxInputIterator.
944-
if (!checkConformance(rawMutableIteratorTy, cxxIteratorProto))
946+
947+
if (!checkConformance(rawIteratorTy, cxxInputIteratorProto) ||
948+
!checkConformance(rawMutableIteratorTy, cxxInputIteratorProto))
945949
return;
946950

951+
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("RawIterator"),
952+
rawIteratorTy);
947953
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("RawMutableIterator"),
948954
rawMutableIteratorTy);
949955
impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxUniqueSet});

stdlib/public/Cxx/CxxSet.swift

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,21 @@ extension CxxSet {
6767
public protocol CxxUniqueSet<Element>: CxxSet {
6868
override associatedtype Element
6969
override associatedtype Size: BinaryInteger
70+
associatedtype RawIterator: UnsafeCxxInputIterator
71+
where RawIterator.Pointee == Element
7072
associatedtype RawMutableIterator: UnsafeCxxInputIterator
7173
where RawMutableIterator.Pointee == Element
7274
override associatedtype InsertionResult
7375
where InsertionResult: CxxPair<RawMutableIterator, Bool>
76+
77+
@discardableResult
78+
mutating func __findUnsafe(_ value: Element) -> RawIterator
79+
80+
@discardableResult
81+
mutating func __eraseUnsafe(_ iter: RawIterator) -> RawMutableIterator
82+
83+
@discardableResult
84+
mutating func __endUnsafe() -> RawIterator
7485
}
7586

7687
extension CxxUniqueSet {
@@ -94,4 +105,19 @@ extension CxxUniqueSet {
94105
let inserted: Bool = insertionResult.second
95106
return (inserted, rawIterator.pointee)
96107
}
108+
109+
/// Removes the given element from the set.
110+
///
111+
/// - Parameter member: An element to remove from the set.
112+
@discardableResult
113+
@inlinable
114+
public mutating func remove(_ member: Element) -> Element? {
115+
let iter = self.__findUnsafe(member)
116+
guard iter != self.__endUnsafe() else {
117+
return nil
118+
}
119+
let value = iter.pointee
120+
self.__eraseUnsafe(iter)
121+
return value
122+
}
97123
}

test/Interop/Cxx/stdlib/use-std-set.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,4 +164,22 @@ StdSetTestSuite.test("UnorderedSetOfCInt.erase") {
164164
expectFalse(s.contains(2))
165165
}
166166

167+
StdSetTestSuite.test("SetOfCInt.remove") {
168+
var s = initSetOfCInt()
169+
expectTrue(s.contains(1))
170+
expectEqual(s.remove(1), 1)
171+
expectFalse(s.contains(1))
172+
expectEqual(s.remove(1), nil)
173+
expectFalse(s.contains(1))
174+
}
175+
176+
StdSetTestSuite.test("UnorderedSetOfCInt.remove") {
177+
var s = initUnorderedSetOfCInt()
178+
expectTrue(s.contains(2))
179+
expectEqual(s.remove(2), 2)
180+
expectFalse(s.contains(2))
181+
expectEqual(s.remove(2), nil)
182+
expectFalse(s.contains(2))
183+
}
184+
167185
runAllTests()

0 commit comments

Comments
 (0)