Skip to content

Commit 713a0a0

Browse files
authored
Merge pull request #67623 from apple/egorzhdan/std-map-mutable-subscript
[cxx-interop] Allow mutating `std::map` from Swift
2 parents 62a986e + 6caaa77 commit 713a0a0

File tree

5 files changed

+146
-23
lines changed

5 files changed

+146
-23
lines changed

lib/ClangImporter/ClangDerivedConformances.cpp

Lines changed: 76 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,31 @@ static bool isConcreteAndValid(ProtocolConformanceRef conformanceRef,
8686
});
8787
}
8888

89+
static FuncDecl *getInsertFunc(NominalTypeDecl *decl,
90+
TypeAliasDecl *valueType) {
91+
ASTContext &ctx = decl->getASTContext();
92+
93+
auto insertId = ctx.getIdentifier("__insertUnsafe");
94+
auto inserts = lookupDirectWithoutExtensions(decl, insertId);
95+
FuncDecl *insert = nullptr;
96+
for (auto candidate : inserts) {
97+
if (auto candidateMethod = dyn_cast<FuncDecl>(candidate)) {
98+
if (!candidateMethod->hasParameterList())
99+
continue;
100+
auto params = candidateMethod->getParameters();
101+
if (params->size() != 1)
102+
continue;
103+
auto param = params->front();
104+
if (param->getType()->getCanonicalType() !=
105+
valueType->getUnderlyingType()->getCanonicalType())
106+
continue;
107+
insert = candidateMethod;
108+
break;
109+
}
110+
}
111+
return insert;
112+
}
113+
89114
static bool isStdDecl(const clang::CXXRecordDecl *clangDecl,
90115
llvm::ArrayRef<StringRef> names) {
91116
if (!clangDecl->isInStdNamespace())
@@ -713,12 +738,16 @@ static bool isStdSetType(const clang::CXXRecordDecl *clangDecl) {
713738
return isStdDecl(clangDecl, {"set", "unordered_set", "multiset"});
714739
}
715740

741+
static bool isStdMapType(const clang::CXXRecordDecl *clangDecl) {
742+
return isStdDecl(clangDecl, {"map", "unordered_map", "multimap"});
743+
}
744+
716745
bool swift::isUnsafeStdMethod(const clang::CXXMethodDecl *methodDecl) {
717746
auto parentDecl =
718747
dyn_cast<clang::CXXRecordDecl>(methodDecl->getDeclContext());
719748
if (!parentDecl)
720749
return false;
721-
if (!isStdSetType(parentDecl))
750+
if (!isStdSetType(parentDecl) && !isStdMapType(parentDecl))
722751
return false;
723752
if (methodDecl->getDeclName().isIdentifier() &&
724753
methodDecl->getName() == "insert")
@@ -747,24 +776,7 @@ void swift::conformToCxxSetIfNeeded(ClangImporter::Implementation &impl,
747776
if (!valueType || !sizeType)
748777
return;
749778

750-
auto insertId = ctx.getIdentifier("__insertUnsafe");
751-
auto inserts = lookupDirectWithoutExtensions(decl, insertId);
752-
FuncDecl *insert = nullptr;
753-
for (auto candidate : inserts) {
754-
if (auto candidateMethod = dyn_cast<FuncDecl>(candidate)) {
755-
if (!candidateMethod->hasParameterList())
756-
continue;
757-
auto params = candidateMethod->getParameters();
758-
if (params->size() != 1)
759-
continue;
760-
auto param = params->front();
761-
if (param->getType()->getCanonicalType() !=
762-
valueType->getUnderlyingType()->getCanonicalType())
763-
continue;
764-
insert = candidateMethod;
765-
break;
766-
}
767-
}
779+
auto insert = getInsertFunc(decl, valueType);
768780
if (!insert)
769781
return;
770782

@@ -844,7 +856,7 @@ void swift::conformToCxxDictionaryIfNeeded(
844856

845857
// Only auto-conform types from the C++ standard library. Custom user types
846858
// might have a similar interface but different semantics.
847-
if (!isStdDecl(clangDecl, {"map", "unordered_map"}))
859+
if (!isStdMapType(clangDecl))
848860
return;
849861

850862
auto keyType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
@@ -853,7 +865,41 @@ void swift::conformToCxxDictionaryIfNeeded(
853865
decl, ctx.getIdentifier("mapped_type"));
854866
auto iterType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
855867
decl, ctx.getIdentifier("const_iterator"));
856-
if (!keyType || !valueType || !iterType)
868+
auto mutableIterType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
869+
decl, ctx.getIdentifier("iterator"));
870+
auto sizeType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
871+
decl, ctx.getIdentifier("size_type"));
872+
auto keyValuePairType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
873+
decl, ctx.getIdentifier("value_type"));
874+
if (!keyType || !valueType || !iterType || !mutableIterType || !sizeType ||
875+
!keyValuePairType)
876+
return;
877+
878+
auto insert = getInsertFunc(decl, keyValuePairType);
879+
if (!insert)
880+
return;
881+
882+
ProtocolDecl *cxxInputIteratorProto =
883+
ctx.getProtocol(KnownProtocolKind::UnsafeCxxInputIterator);
884+
ProtocolDecl *cxxMutableInputIteratorProto =
885+
ctx.getProtocol(KnownProtocolKind::UnsafeCxxMutableInputIterator);
886+
if (!cxxInputIteratorProto || !cxxMutableInputIteratorProto)
887+
return;
888+
889+
auto rawIteratorTy = iterType->getUnderlyingType();
890+
auto rawMutableIteratorTy = mutableIterType->getUnderlyingType();
891+
892+
// Check if RawIterator conforms to UnsafeCxxInputIterator.
893+
ModuleDecl *module = decl->getModuleContext();
894+
auto rawIteratorConformanceRef =
895+
module->lookupConformance(rawIteratorTy, cxxInputIteratorProto);
896+
if (!isConcreteAndValid(rawIteratorConformanceRef, module))
897+
return;
898+
899+
// Check if RawMutableIterator conforms to UnsafeCxxMutableInputIterator.
900+
auto rawMutableIteratorConformanceRef = module->lookupConformance(
901+
rawMutableIteratorTy, cxxMutableInputIteratorProto);
902+
if (!isConcreteAndValid(rawMutableIteratorConformanceRef, module))
857903
return;
858904

859905
// Make the original subscript that returns a non-optional value unavailable.
@@ -869,7 +915,15 @@ void swift::conformToCxxDictionaryIfNeeded(
869915
impl.addSynthesizedTypealias(decl, ctx.Id_Key, keyType->getUnderlyingType());
870916
impl.addSynthesizedTypealias(decl, ctx.Id_Value,
871917
valueType->getUnderlyingType());
918+
impl.addSynthesizedTypealias(decl, ctx.Id_Element,
919+
keyValuePairType->getUnderlyingType());
872920
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("RawIterator"),
873-
iterType->getUnderlyingType());
921+
rawIteratorTy);
922+
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("RawMutableIterator"),
923+
rawMutableIteratorTy);
924+
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Size"),
925+
sizeType->getUnderlyingType());
926+
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("InsertionResult"),
927+
insert->getResultInterfaceType());
874928
impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxDictionary});
875929
}

stdlib/public/Cxx/CxxDictionary.swift

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,33 @@
1717
public protocol CxxDictionary<Key, Value> {
1818
associatedtype Key
1919
associatedtype Value
20+
associatedtype Element: CxxPair<Key, Value>
2021
associatedtype RawIterator: UnsafeCxxInputIterator
21-
where RawIterator.Pointee: CxxPair<Key, Value>
22+
where RawIterator.Pointee == Element
23+
associatedtype RawMutableIterator: UnsafeCxxMutableInputIterator
24+
where RawMutableIterator.Pointee == Element
25+
associatedtype Size: BinaryInteger
26+
associatedtype InsertionResult
2227

2328
/// Do not implement this function manually in Swift.
2429
func __findUnsafe(_ key: Key) -> RawIterator
2530

31+
/// Do not implement this function manually in Swift.
32+
mutating func __findMutatingUnsafe(_ key: Key) -> RawMutableIterator
33+
34+
/// Do not implement this function manually in Swift.
35+
@discardableResult
36+
mutating func __insertUnsafe(_ element: Element) -> InsertionResult
37+
38+
/// Do not implement this function manually in Swift.
39+
@discardableResult
40+
mutating func erase(_ key: Key) -> Size
41+
2642
/// Do not implement this function manually in Swift.
2743
func __endUnsafe() -> RawIterator
44+
45+
/// Do not implement this function manually in Swift.
46+
mutating func __endMutatingUnsafe() -> RawMutableIterator
2847
}
2948

3049
extension CxxDictionary {
@@ -37,5 +56,20 @@ extension CxxDictionary {
3756
}
3857
return iter.pointee.second
3958
}
59+
set(newValue) {
60+
guard let newValue = newValue else {
61+
self.erase(key)
62+
return
63+
}
64+
var iter = self.__findMutatingUnsafe(key)
65+
if iter != self.__endMutatingUnsafe() {
66+
// This key already exists in the dictionary.
67+
iter.pointee.second = newValue
68+
} else {
69+
// Create a std::pair<key_type, mapped_type>.
70+
let keyValuePair = Element(first: key, second: newValue)
71+
self.__insertUnsafe(keyValuePair)
72+
}
73+
}
4074
}
4175
}

stdlib/public/Cxx/CxxPair.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ public protocol CxxPair<First, Second> {
1717
associatedtype First
1818
associatedtype Second
1919

20+
init(first: First, second: Second) // memberwise init, synthesized by Swift
21+
2022
var first: First { get set }
2123
var second: Second { get set }
2224
}

test/Interop/Cxx/stdlib/Inputs/std-map.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33

44
#include <map>
55
#include <unordered_map>
6+
#include <string>
67

78
using Map = std::map<int, int>;
9+
using MapStrings = std::map<std::string, std::string>;
810
using UnorderedMap = std::unordered_map<int, int>;
911

1012
inline Map initMap() { return {{1, 3}, {2, 2}, {3, 3}}; }

test/Interop/Cxx/stdlib/use-std-map.swift

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,27 @@ StdMapTestSuite.test("Map.subscript") {
2727
expectEqual(m[3], 3)
2828
expectNil(m[-1])
2929
expectNil(m[5])
30+
31+
m[1] = 111
32+
expectEqual(m[1], 111)
33+
34+
m[5] = 555
35+
expectEqual(m[5], 555)
36+
37+
m[5] = nil
38+
expectNil(m[5])
39+
expectNil(m[5])
40+
}
41+
42+
StdMapTestSuite.test("MapStrings.subscript") {
43+
var m = MapStrings()
44+
expectNil(m[std.string()])
45+
expectNil(m[std.string()])
46+
m[std.string()] = std.string()
47+
expectNotNil(m[std.string()])
48+
49+
m[std.string("abc")] = std.string("qwe")
50+
expectEqual(m[std.string("abc")], std.string("qwe"))
3051
}
3152

3253
StdMapTestSuite.test("UnorderedMap.subscript") {
@@ -37,6 +58,16 @@ StdMapTestSuite.test("UnorderedMap.subscript") {
3758
expectEqual(m[3], 3)
3859
expectNil(m[-1])
3960
expectNil(m[5])
61+
62+
m[1] = 777
63+
expectEqual(m[1], 777)
64+
65+
m[-1] = 228
66+
expectEqual(m[-1], 228)
67+
68+
m[-1] = nil
69+
expectNil(m[-1])
70+
expectNil(m[-1])
4071
}
4172

4273
StdMapTestSuite.test("Map.erase") {

0 commit comments

Comments
 (0)