Skip to content

Commit e064d23

Browse files
authored
Merge pull request #66764 from apple/egorzhdan/cxx-set-insert
[cxx-interop] Allow inserting elements into `std::set` from Swift
2 parents 0f20144 + b79b65c commit e064d23

File tree

6 files changed

+85
-1
lines changed

6 files changed

+85
-1
lines changed

include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ PROTOCOL(CxxPair)
114114
PROTOCOL(CxxSet)
115115
PROTOCOL(CxxRandomAccessCollection)
116116
PROTOCOL(CxxSequence)
117+
PROTOCOL(CxxUniqueSet)
117118
PROTOCOL(UnsafeCxxInputIterator)
118119
PROTOCOL(UnsafeCxxRandomAccessIterator)
119120

lib/AST/ASTContext.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,6 +1138,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
11381138
case KnownProtocolKind::CxxRandomAccessCollection:
11391139
case KnownProtocolKind::CxxSet:
11401140
case KnownProtocolKind::CxxSequence:
1141+
case KnownProtocolKind::CxxUniqueSet:
11411142
case KnownProtocolKind::UnsafeCxxInputIterator:
11421143
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
11431144
M = getLoadedModule(Id_Cxx);

lib/ClangImporter/ClangDerivedConformances.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,34 @@ void swift::conformToCxxSetIfNeeded(ClangImporter::Implementation &impl,
614614
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("InsertionResult"),
615615
insert->getResultInterfaceType());
616616
impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxSet});
617+
618+
// If this isn't a std::multiset, try to also synthesize the conformance to
619+
// CxxUniqueSet.
620+
if (!isStdDecl(clangDecl, {"set", "unordered_set"}))
621+
return;
622+
623+
ProtocolDecl *cxxIteratorProto =
624+
ctx.getProtocol(KnownProtocolKind::UnsafeCxxInputIterator);
625+
if (!cxxIteratorProto)
626+
return;
627+
628+
auto rawMutableIteratorType =
629+
lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
630+
decl, ctx.getIdentifier("iterator"));
631+
if (!rawMutableIteratorType)
632+
return;
633+
634+
auto rawMutableIteratorTy = rawMutableIteratorType->getUnderlyingType();
635+
// Check if RawMutableIterator conforms to UnsafeCxxInputIterator.
636+
ModuleDecl *module = decl->getModuleContext();
637+
auto rawIteratorConformanceRef =
638+
module->lookupConformance(rawMutableIteratorTy, cxxIteratorProto);
639+
if (!isConcreteAndValid(rawIteratorConformanceRef, module))
640+
return;
641+
642+
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("RawMutableIterator"),
643+
rawMutableIteratorTy);
644+
impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxUniqueSet});
617645
}
618646

619647
void swift::conformToCxxPairIfNeeded(ClangImporter::Implementation &impl,

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6288,6 +6288,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
62886288
case KnownProtocolKind::CxxRandomAccessCollection:
62896289
case KnownProtocolKind::CxxSet:
62906290
case KnownProtocolKind::CxxSequence:
6291+
case KnownProtocolKind::CxxUniqueSet:
62916292
case KnownProtocolKind::UnsafeCxxInputIterator:
62926293
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
62936294
case KnownProtocolKind::Executor:

stdlib/public/Cxx/CxxSet.swift

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
public protocol CxxSet<Element> {
1414
associatedtype Element
1515
associatedtype Size: BinaryInteger
16-
associatedtype InsertionResult // std::pair<iterator, bool>
16+
17+
// std::pair<iterator, bool> for std::set and std::unordered_set
18+
// iterator for std::multiset
19+
associatedtype InsertionResult
1720

1821
init()
1922

@@ -43,3 +46,25 @@ extension CxxSet {
4346
return count(element) > 0
4447
}
4548
}
49+
50+
public protocol CxxUniqueSet<Element>: CxxSet {
51+
override associatedtype Element
52+
override associatedtype Size: BinaryInteger
53+
associatedtype RawMutableIterator: UnsafeCxxInputIterator
54+
where RawMutableIterator.Pointee == Element
55+
override associatedtype InsertionResult
56+
where InsertionResult: CxxPair<RawMutableIterator, Bool>
57+
}
58+
59+
extension CxxUniqueSet {
60+
@inlinable
61+
@discardableResult
62+
public mutating func insert(
63+
_ element: Element
64+
) -> (inserted: Bool, memberAfterInsert: Element) {
65+
let insertionResult = self.__insertUnsafe(element)
66+
let rawIterator: RawMutableIterator = insertionResult.first
67+
let inserted: Bool = insertionResult.second
68+
return (inserted, rawIterator.pointee)
69+
}
70+
}

test/Interop/Cxx/stdlib/use-std-set.swift

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,32 @@ StdSetTestSuite.test("UnorderedSetOfCInt.init()") {
6161
expectTrue(s.contains(3))
6262
}
6363

64+
StdSetTestSuite.test("SetOfCInt.insert") {
65+
var s = SetOfCInt()
66+
expectFalse(s.contains(123))
67+
68+
let res1 = s.insert(123)
69+
expectTrue(res1.inserted)
70+
expectTrue(s.contains(123))
71+
72+
let res2 = s.insert(123)
73+
expectFalse(res2.inserted)
74+
expectTrue(s.contains(123))
75+
}
76+
77+
#if !os(Linux) // FIXME: https://github.com/apple/swift/issues/66767 / rdar://105220600
78+
StdSetTestSuite.test("UnorderedSetOfCInt.insert") {
79+
var s = UnorderedSetOfCInt()
80+
expectFalse(s.contains(123))
81+
82+
let res1 = s.insert(123)
83+
expectTrue(res1.inserted)
84+
expectTrue(s.contains(123))
85+
86+
let res2 = s.insert(123)
87+
expectFalse(res2.inserted)
88+
expectTrue(s.contains(123))
89+
}
90+
#endif
91+
6492
runAllTests()

0 commit comments

Comments
 (0)