@@ -86,6 +86,31 @@ static bool isConcreteAndValid(ProtocolConformanceRef conformanceRef,
86
86
});
87
87
}
88
88
89
+ static FuncDecl *getInsertFunc (NominalTypeDecl *decl,
90
+ TypeAliasDecl *valueType) {
91
+ ASTContext &ctx = decl->getASTContext ();
92
+
93
+ auto insertId = ctx.getIdentifier (" __insertUnsafe" );
94
+ auto inserts = lookupDirectWithoutExtensions (decl, insertId);
95
+ FuncDecl *insert = nullptr ;
96
+ for (auto candidate : inserts) {
97
+ if (auto candidateMethod = dyn_cast<FuncDecl>(candidate)) {
98
+ if (!candidateMethod->hasParameterList ())
99
+ continue ;
100
+ auto params = candidateMethod->getParameters ();
101
+ if (params->size () != 1 )
102
+ continue ;
103
+ auto param = params->front ();
104
+ if (param->getType ()->getCanonicalType () !=
105
+ valueType->getUnderlyingType ()->getCanonicalType ())
106
+ continue ;
107
+ insert = candidateMethod;
108
+ break ;
109
+ }
110
+ }
111
+ return insert;
112
+ }
113
+
89
114
static bool isStdDecl (const clang::CXXRecordDecl *clangDecl,
90
115
llvm::ArrayRef<StringRef> names) {
91
116
if (!clangDecl->isInStdNamespace ())
@@ -713,12 +738,16 @@ static bool isStdSetType(const clang::CXXRecordDecl *clangDecl) {
713
738
return isStdDecl (clangDecl, {" set" , " unordered_set" , " multiset" });
714
739
}
715
740
741
+ static bool isStdMapType (const clang::CXXRecordDecl *clangDecl) {
742
+ return isStdDecl (clangDecl, {" map" , " unordered_map" , " multimap" });
743
+ }
744
+
716
745
bool swift::isUnsafeStdMethod (const clang::CXXMethodDecl *methodDecl) {
717
746
auto parentDecl =
718
747
dyn_cast<clang::CXXRecordDecl>(methodDecl->getDeclContext ());
719
748
if (!parentDecl)
720
749
return false ;
721
- if (!isStdSetType (parentDecl))
750
+ if (!isStdSetType (parentDecl) && ! isStdMapType (parentDecl) )
722
751
return false ;
723
752
if (methodDecl->getDeclName ().isIdentifier () &&
724
753
methodDecl->getName () == " insert" )
@@ -747,24 +776,7 @@ void swift::conformToCxxSetIfNeeded(ClangImporter::Implementation &impl,
747
776
if (!valueType || !sizeType)
748
777
return ;
749
778
750
- auto insertId = ctx.getIdentifier (" __insertUnsafe" );
751
- auto inserts = lookupDirectWithoutExtensions (decl, insertId);
752
- FuncDecl *insert = nullptr ;
753
- for (auto candidate : inserts) {
754
- if (auto candidateMethod = dyn_cast<FuncDecl>(candidate)) {
755
- if (!candidateMethod->hasParameterList ())
756
- continue ;
757
- auto params = candidateMethod->getParameters ();
758
- if (params->size () != 1 )
759
- continue ;
760
- auto param = params->front ();
761
- if (param->getType ()->getCanonicalType () !=
762
- valueType->getUnderlyingType ()->getCanonicalType ())
763
- continue ;
764
- insert = candidateMethod;
765
- break ;
766
- }
767
- }
779
+ auto insert = getInsertFunc (decl, valueType);
768
780
if (!insert)
769
781
return ;
770
782
@@ -844,7 +856,7 @@ void swift::conformToCxxDictionaryIfNeeded(
844
856
845
857
// Only auto-conform types from the C++ standard library. Custom user types
846
858
// might have a similar interface but different semantics.
847
- if (!isStdDecl (clangDecl, { " map " , " unordered_map " } ))
859
+ if (!isStdMapType (clangDecl))
848
860
return ;
849
861
850
862
auto keyType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
@@ -853,7 +865,41 @@ void swift::conformToCxxDictionaryIfNeeded(
853
865
decl, ctx.getIdentifier (" mapped_type" ));
854
866
auto iterType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
855
867
decl, ctx.getIdentifier (" const_iterator" ));
856
- if (!keyType || !valueType || !iterType)
868
+ auto mutableIterType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
869
+ decl, ctx.getIdentifier (" iterator" ));
870
+ auto sizeType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
871
+ decl, ctx.getIdentifier (" size_type" ));
872
+ auto keyValuePairType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
873
+ decl, ctx.getIdentifier (" value_type" ));
874
+ if (!keyType || !valueType || !iterType || !mutableIterType || !sizeType ||
875
+ !keyValuePairType)
876
+ return ;
877
+
878
+ auto insert = getInsertFunc (decl, keyValuePairType);
879
+ if (!insert)
880
+ return ;
881
+
882
+ ProtocolDecl *cxxInputIteratorProto =
883
+ ctx.getProtocol (KnownProtocolKind::UnsafeCxxInputIterator);
884
+ ProtocolDecl *cxxMutableInputIteratorProto =
885
+ ctx.getProtocol (KnownProtocolKind::UnsafeCxxMutableInputIterator);
886
+ if (!cxxInputIteratorProto || !cxxMutableInputIteratorProto)
887
+ return ;
888
+
889
+ auto rawIteratorTy = iterType->getUnderlyingType ();
890
+ auto rawMutableIteratorTy = mutableIterType->getUnderlyingType ();
891
+
892
+ // Check if RawIterator conforms to UnsafeCxxInputIterator.
893
+ ModuleDecl *module = decl->getModuleContext ();
894
+ auto rawIteratorConformanceRef =
895
+ module ->lookupConformance (rawIteratorTy, cxxInputIteratorProto);
896
+ if (!isConcreteAndValid (rawIteratorConformanceRef, module ))
897
+ return ;
898
+
899
+ // Check if RawMutableIterator conforms to UnsafeCxxMutableInputIterator.
900
+ auto rawMutableIteratorConformanceRef = module ->lookupConformance (
901
+ rawMutableIteratorTy, cxxMutableInputIteratorProto);
902
+ if (!isConcreteAndValid (rawMutableIteratorConformanceRef, module ))
857
903
return ;
858
904
859
905
// Make the original subscript that returns a non-optional value unavailable.
@@ -869,7 +915,15 @@ void swift::conformToCxxDictionaryIfNeeded(
869
915
impl.addSynthesizedTypealias (decl, ctx.Id_Key , keyType->getUnderlyingType ());
870
916
impl.addSynthesizedTypealias (decl, ctx.Id_Value ,
871
917
valueType->getUnderlyingType ());
918
+ impl.addSynthesizedTypealias (decl, ctx.Id_Element ,
919
+ keyValuePairType->getUnderlyingType ());
872
920
impl.addSynthesizedTypealias (decl, ctx.getIdentifier (" RawIterator" ),
873
- iterType->getUnderlyingType ());
921
+ rawIteratorTy);
922
+ impl.addSynthesizedTypealias (decl, ctx.getIdentifier (" RawMutableIterator" ),
923
+ rawMutableIteratorTy);
924
+ impl.addSynthesizedTypealias (decl, ctx.getIdentifier (" Size" ),
925
+ sizeType->getUnderlyingType ());
926
+ impl.addSynthesizedTypealias (decl, ctx.getIdentifier (" InsertionResult" ),
927
+ insert->getResultInterfaceType ());
874
928
impl.addSynthesizedProtocolAttrs (decl, {KnownProtocolKind::CxxDictionary});
875
929
}
0 commit comments