Skip to content

Commit 8fba23f

Browse files
authored
Merge pull request #67154 from apple/egorzhdan/5.9-cxx-set-insert
🍒[cxx-interop] Allow inserting elements into `std::set` from Swift
2 parents 3c56626 + 7211434 commit 8fba23f

File tree

6 files changed

+85
-1
lines changed

6 files changed

+85
-1
lines changed

include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ PROTOCOL(CxxPair)
114114
PROTOCOL(CxxSet)
115115
PROTOCOL(CxxRandomAccessCollection)
116116
PROTOCOL(CxxSequence)
117+
PROTOCOL(CxxUniqueSet)
117118
PROTOCOL(UnsafeCxxInputIterator)
118119
PROTOCOL(UnsafeCxxRandomAccessIterator)
119120

lib/AST/ASTContext.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,6 +1128,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
11281128
case KnownProtocolKind::CxxRandomAccessCollection:
11291129
case KnownProtocolKind::CxxSet:
11301130
case KnownProtocolKind::CxxSequence:
1131+
case KnownProtocolKind::CxxUniqueSet:
11311132
case KnownProtocolKind::UnsafeCxxInputIterator:
11321133
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
11331134
M = getLoadedModule(Id_Cxx);

lib/ClangImporter/ClangDerivedConformances.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,34 @@ void swift::conformToCxxSetIfNeeded(ClangImporter::Implementation &impl,
631631
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("InsertionResult"),
632632
insert->getResultInterfaceType());
633633
impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxSet});
634+
635+
// If this isn't a std::multiset, try to also synthesize the conformance to
636+
// CxxUniqueSet.
637+
if (!isStdDecl(clangDecl, {"set", "unordered_set"}))
638+
return;
639+
640+
ProtocolDecl *cxxIteratorProto =
641+
ctx.getProtocol(KnownProtocolKind::UnsafeCxxInputIterator);
642+
if (!cxxIteratorProto)
643+
return;
644+
645+
auto rawMutableIteratorType =
646+
lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
647+
decl, ctx.getIdentifier("iterator"));
648+
if (!rawMutableIteratorType)
649+
return;
650+
651+
auto rawMutableIteratorTy = rawMutableIteratorType->getUnderlyingType();
652+
// Check if RawMutableIterator conforms to UnsafeCxxInputIterator.
653+
ModuleDecl *module = decl->getModuleContext();
654+
auto rawIteratorConformanceRef =
655+
module->lookupConformance(rawMutableIteratorTy, cxxIteratorProto);
656+
if (!isConcreteAndValid(rawIteratorConformanceRef, module))
657+
return;
658+
659+
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("RawMutableIterator"),
660+
rawMutableIteratorTy);
661+
impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxUniqueSet});
634662
}
635663

636664
void swift::conformToCxxPairIfNeeded(ClangImporter::Implementation &impl,

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6272,6 +6272,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
62726272
case KnownProtocolKind::CxxRandomAccessCollection:
62736273
case KnownProtocolKind::CxxSet:
62746274
case KnownProtocolKind::CxxSequence:
6275+
case KnownProtocolKind::CxxUniqueSet:
62756276
case KnownProtocolKind::UnsafeCxxInputIterator:
62766277
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
62776278
case KnownProtocolKind::Executor:

stdlib/public/Cxx/CxxSet.swift

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
public protocol CxxSet<Element> {
1414
associatedtype Element
1515
associatedtype Size: BinaryInteger
16-
associatedtype InsertionResult // std::pair<iterator, bool>
16+
17+
// std::pair<iterator, bool> for std::set and std::unordered_set
18+
// iterator for std::multiset
19+
associatedtype InsertionResult
1720

1821
init()
1922

@@ -43,3 +46,25 @@ extension CxxSet {
4346
return count(element) > 0
4447
}
4548
}
49+
50+
public protocol CxxUniqueSet<Element>: CxxSet {
51+
override associatedtype Element
52+
override associatedtype Size: BinaryInteger
53+
associatedtype RawMutableIterator: UnsafeCxxInputIterator
54+
where RawMutableIterator.Pointee == Element
55+
override associatedtype InsertionResult
56+
where InsertionResult: CxxPair<RawMutableIterator, Bool>
57+
}
58+
59+
extension CxxUniqueSet {
60+
@inlinable
61+
@discardableResult
62+
public mutating func insert(
63+
_ element: Element
64+
) -> (inserted: Bool, memberAfterInsert: Element) {
65+
let insertionResult = self.__insertUnsafe(element)
66+
let rawIterator: RawMutableIterator = insertionResult.first
67+
let inserted: Bool = insertionResult.second
68+
return (inserted, rawIterator.pointee)
69+
}
70+
}

test/Interop/Cxx/stdlib/use-std-set.swift

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,32 @@ StdSetTestSuite.test("UnorderedSetOfCInt.init()") {
6161
expectTrue(s.contains(3))
6262
}
6363

64+
StdSetTestSuite.test("SetOfCInt.insert") {
65+
var s = SetOfCInt()
66+
expectFalse(s.contains(123))
67+
68+
let res1 = s.insert(123)
69+
expectTrue(res1.inserted)
70+
expectTrue(s.contains(123))
71+
72+
let res2 = s.insert(123)
73+
expectFalse(res2.inserted)
74+
expectTrue(s.contains(123))
75+
}
76+
77+
#if !os(Linux) // FIXME: https://github.com/apple/swift/issues/66767 / rdar://105220600
78+
StdSetTestSuite.test("UnorderedSetOfCInt.insert") {
79+
var s = UnorderedSetOfCInt()
80+
expectFalse(s.contains(123))
81+
82+
let res1 = s.insert(123)
83+
expectTrue(res1.inserted)
84+
expectTrue(s.contains(123))
85+
86+
let res2 = s.insert(123)
87+
expectFalse(res2.inserted)
88+
expectTrue(s.contains(123))
89+
}
90+
#endif
91+
6492
runAllTests()

0 commit comments

Comments
 (0)