@@ -95,8 +95,9 @@ Type swift::getConcreteReplacementForProtocolActorSystemType(ValueDecl *member)
95
95
llvm_unreachable (" Unable to fetch ActorSystem type!" );
96
96
}
97
97
98
- Type swift::getConcreteReplacementForMemberSerializationRequirement (
99
- ValueDecl *member) {
98
+ Type swift::getSerializationRequirementTypesForMember (
99
+ ValueDecl *member,
100
+ llvm::SmallPtrSet<ProtocolDecl *, 2 > &serializationRequirements) {
100
101
auto &C = member->getASTContext ();
101
102
auto *DC = member->getDeclContext ();
102
103
auto DA = C.getDistributedActorDecl ();
@@ -106,17 +107,21 @@ Type swift::getConcreteReplacementForMemberSerializationRequirement(
106
107
return getDistributedSerializationRequirementType (classDecl, C.getDistributedActorDecl ());
107
108
}
108
109
109
- // / === Maybe the value is declared in a protocol?
110
- if (auto protocol = DC->getSelfProtocolDecl ()) {
110
+ auto SerReqAssocType = DA->getAssociatedType (C.Id_SerializationRequirement )
111
+ ->getDeclaredInterfaceType ();
112
+
113
+ if (DC->getSelfProtocolDecl ()) {
111
114
GenericSignature signature;
112
115
if (auto *genericContext = member->getAsGenericContext ()) {
113
116
signature = genericContext->getGenericSignature ();
114
117
} else {
115
118
signature = DC->getGenericSignatureOfContext ();
116
119
}
117
120
118
- auto SerReqAssocType = DA->getAssociatedType (C.Id_SerializationRequirement )
119
- ->getDeclaredInterfaceType ();
121
+ // Also store all `SerializationRequirement : SomeProtocol` requirements
122
+ for (auto proto: signature->getRequiredProtocols (SerReqAssocType)) {
123
+ serializationRequirements.insert (proto);
124
+ }
120
125
121
126
// Note that this may be null, e.g. if we're a distributed func inside
122
127
// a protocol that did not declare a specific actor system requirement.
@@ -178,13 +183,7 @@ Type swift::getDistributedActorSystemResultHandlerType(
178
183
auto module = system->getParentModule ();
179
184
Type selfType = system->getSelfInterfaceType ();
180
185
auto conformance = module ->lookupConformance (selfType, DAS);
181
- auto witness =
182
- conformance.getTypeWitnessByName (selfType, ctx.Id_ResultHandler );
183
- if (auto alias = dyn_cast<TypeAliasType>(witness.getPointer ())) {
184
- return alias->getDecl ()->getUnderlyingType ();
185
- } else {
186
- return witness;
187
- }
186
+ return conformance.getTypeWitnessByName (selfType, ctx.Id_ResultHandler );
188
187
}
189
188
190
189
Type swift::getDistributedActorSystemInvocationEncoderType (NominalTypeDecl *system) {
@@ -346,63 +345,39 @@ swift::getDistributedSerializationRequirements(
346
345
if (existentialRequirementTy->isAny ())
347
346
return true ; // we're done here, any means there are no requirements
348
347
349
- if (!existentialRequirementTy->isExistentialType ()) {
350
- // SerializationRequirement must be an existential type
351
- return false ;
352
- }
353
-
354
- ExistentialType *serialReqType = existentialRequirementTy
355
- ->castTo <ExistentialType>();
348
+ auto *serialReqType = existentialRequirementTy->getAs <ExistentialType>();
356
349
if (!serialReqType || serialReqType->hasError ()) {
357
350
return false ;
358
351
}
359
352
360
- auto desugaredTy = serialReqType->getConstraintType ()->getDesugaredType ();
361
- auto flattenedRequirements =
362
- flattenDistributedSerializationTypeToRequiredProtocols (
363
- desugaredTy);
364
- for (auto p : flattenedRequirements) {
353
+ auto layout = serialReqType->getExistentialLayout ();
354
+ for (auto p : layout.getProtocols ()) {
365
355
requirementProtos.insert (p);
366
356
}
367
357
368
358
return true ;
369
359
}
370
360
371
- llvm::SmallPtrSet<ProtocolDecl *, 2 >
372
- swift::flattenDistributedSerializationTypeToRequiredProtocols (
373
- TypeBase *serializationRequirement) {
374
- llvm::SmallPtrSet<ProtocolDecl *, 2 > serializationReqs;
375
- if (auto composition =
376
- serializationRequirement->getAs <ProtocolCompositionType>()) {
377
- for (auto member : composition->getMembers ()) {
378
- if (auto comp = member->getAs <ProtocolCompositionType>()) {
379
- for (auto protocol :
380
- flattenDistributedSerializationTypeToRequiredProtocols (comp)) {
381
- serializationReqs.insert (protocol);
382
- }
383
- } else if (auto *protocol = member->getAs <ProtocolType>()) {
384
- serializationReqs.insert (protocol->getDecl ());
385
- }
386
- }
387
- } else {
388
- auto protocol = serializationRequirement->castTo <ProtocolType>()->getDecl ();
389
- serializationReqs.insert (protocol);
390
- }
391
-
392
- return serializationReqs;
393
- }
394
-
395
361
bool swift::checkDistributedSerializationRequirementIsExactlyCodable (
396
362
ASTContext &C,
397
- const llvm::SmallPtrSetImpl<ProtocolDecl *> &allRequirements) {
363
+ Type type) {
364
+ if (!type)
365
+ return false ;
366
+
367
+ if (type->hasError ())
368
+ return false ;
369
+
398
370
auto encodable = C.getProtocol (KnownProtocolKind::Encodable);
399
371
auto decodable = C.getProtocol (KnownProtocolKind::Decodable);
400
372
401
- if (allRequirements.size () != 2 )
373
+ auto layout = type->getExistentialLayout ();
374
+ auto protocols = layout.getProtocols ();
375
+
376
+ if (protocols.size () != 2 )
402
377
return false ;
403
378
404
- return allRequirements. count (encodable) &&
405
- allRequirements. count ( decodable);
379
+ return std:: count (protocols. begin (), protocols. end (), encodable) == 1 &&
380
+ std::count (protocols. begin (), protocols. end (), decodable) == 1 ;
406
381
}
407
382
408
383
/* *****************************************************************************/
@@ -571,25 +546,19 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
571
546
572
547
// --- Check requirement: conforms_to: Act DistributedActor
573
548
auto actorReq = requirements[0 ];
574
- auto distActorTy = C.getProtocol (KnownProtocolKind::DistributedActor)
575
- ->getInterfaceType ()
576
- ->getMetatypeInstanceType ();
577
549
if (actorReq.getKind () != RequirementKind::Conformance) {
578
550
return false ;
579
551
}
580
- if (!actorReq.getSecondType ()->isEqual (distActorTy )) {
552
+ if (!actorReq.getProtocolDecl ()->isSpecificProtocol (KnownProtocolKind::DistributedActor )) {
581
553
return false ;
582
554
}
583
555
584
556
// --- Check requirement: conforms_to: Err Error
585
557
auto errorReq = requirements[1 ];
586
- auto errorTy = C.getProtocol (KnownProtocolKind::Error)
587
- ->getInterfaceType ()
588
- ->getMetatypeInstanceType ();
589
558
if (errorReq.getKind () != RequirementKind::Conformance) {
590
559
return false ;
591
560
}
592
- if (!errorReq.getSecondType ()->isEqual (errorTy )) {
561
+ if (!errorReq.getProtocolDecl ()->isSpecificProtocol (KnownProtocolKind::Error )) {
593
562
return false ;
594
563
}
595
564
@@ -604,10 +573,9 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
604
573
assert (ResParam && " Non void function, yet no Res generic parameter found" );
605
574
if (auto func = dyn_cast<FuncDecl>(this )) {
606
575
auto resultType = func->mapTypeIntoContext (func->getResultInterfaceType ())
607
- ->getMetatypeInstanceType ()
608
- ->getDesugaredType ();
576
+ ->getMetatypeInstanceType ();
609
577
auto resultParamType = func->mapTypeIntoContext (
610
- ResParam->getInterfaceType ()-> getMetatypeInstanceType ());
578
+ ResParam->getDeclaredInterfaceType ());
611
579
// The result of the function must be the `Res` generic argument.
612
580
if (!resultType->isEqual (resultParamType)) {
613
581
return false ;
@@ -803,12 +771,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
803
771
804
772
// the <Value> of the RemoteCallArgument<Value>
805
773
auto remoteCallArgValueGenericTy =
806
- mapTypeIntoContext (argGenericParams[0 ]->getInterfaceType ())
807
- ->getDesugaredType ()
808
- ->getMetatypeInstanceType ();
774
+ mapTypeIntoContext (argGenericParams[0 ]->getDeclaredInterfaceType ());
809
775
// expected (the <Value> from the recordArgument<Value>)
810
776
auto expectedGenericParamTy = mapTypeIntoContext (
811
- ArgumentParam->getInterfaceType ()-> getMetatypeInstanceType ());
777
+ ArgumentParam->getDeclaredInterfaceType ());
812
778
813
779
if (!remoteCallArgValueGenericTy->isEqual (expectedGenericParamTy)) {
814
780
return false ;
@@ -938,11 +904,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordReturnType() con
938
904
// ...
939
905
940
906
auto resultType = func->mapTypeIntoContext (argumentParam->getInterfaceType ())
941
- ->getMetatypeInstanceType ()
942
- ->getDesugaredType ();
907
+ ->getMetatypeInstanceType ();
943
908
944
909
auto resultParamType = func->mapTypeIntoContext (
945
- ArgumentParam->getInterfaceType ()-> getMetatypeInstanceType ());
910
+ ArgumentParam->getDeclaredInterfaceType ());
946
911
947
912
// The result of the function must be the `Res` generic argument.
948
913
if (!resultType->isEqual (resultParamType)) {
@@ -1052,13 +1017,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordErrorType() cons
1052
1017
1053
1018
// --- Check requirement: conforms_to: Err Error
1054
1019
auto errorReq = requirements[0 ];
1055
- auto errorTy = C.getProtocol (KnownProtocolKind::Error)
1056
- ->getInterfaceType ()
1057
- ->getMetatypeInstanceType ();
1058
1020
if (errorReq.getKind () != RequirementKind::Conformance) {
1059
1021
return false ;
1060
1022
}
1061
- if (!errorReq.getSecondType ()->isEqual (errorTy )) {
1023
+ if (!errorReq.getProtocolDecl ()->isSpecificProtocol (KnownProtocolKind::Error )) {
1062
1024
return false ;
1063
1025
}
1064
1026
@@ -1145,10 +1107,9 @@ AbstractFunctionDecl::isDistributedTargetInvocationDecoderDecodeNextArgument() c
1145
1107
// --- Check: Argument: SerializationRequirement
1146
1108
GenericTypeParamDecl *ArgumentParam = genericParams->getParams ()[0 ];
1147
1109
auto resultType = func->mapTypeIntoContext (func->getResultInterfaceType ())
1148
- ->getMetatypeInstanceType ()
1149
- ->getDesugaredType ();
1110
+ ->getMetatypeInstanceType ();
1150
1111
auto resultParamType = func->mapTypeIntoContext (
1151
- ArgumentParam->getInterfaceType ()-> getMetatypeInstanceType ());
1112
+ ArgumentParam->getDeclaredInterfaceType ());
1152
1113
// The result of the function must be the `Res` generic argument.
1153
1114
if (!resultType->isEqual (resultParamType)) {
1154
1115
return false ;
@@ -1243,11 +1204,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const
1243
1204
// === Check generic parameters in detail
1244
1205
// --- Check: Argument: SerializationRequirement
1245
1206
GenericTypeParamDecl *ArgumentParam = genericParams->getParams ()[0 ];
1246
- auto argumentType = func->mapTypeIntoContext (valueParam->getInterfaceType ())
1247
- ->getMetatypeInstanceType ()
1248
- ->getDesugaredType ();
1207
+ auto argumentType = func->mapTypeIntoContext (
1208
+ valueParam->getInterfaceType ()->getMetatypeInstanceType ());
1249
1209
auto resultParamType = func->mapTypeIntoContext (
1250
- ArgumentParam->getInterfaceType ()-> getMetatypeInstanceType ());
1210
+ ArgumentParam->getDeclaredInterfaceType ());
1251
1211
// The result of the function must be the `Res` generic argument.
1252
1212
if (!argumentType->isEqual (resultParamType)) {
1253
1213
return false ;
@@ -1268,50 +1228,6 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const
1268
1228
return true ;
1269
1229
}
1270
1230
1271
- llvm::SmallPtrSet<ProtocolDecl *, 2 >
1272
- swift::extractDistributedSerializationRequirements (
1273
- ASTContext &C, ArrayRef<Requirement> allRequirements) {
1274
- llvm::SmallPtrSet<ProtocolDecl *, 2 > serializationReqs;
1275
- auto DA = C.getDistributedActorDecl ();
1276
- auto daSerializationReqAssocType =
1277
- DA->getAssociatedType (C.Id_SerializationRequirement );
1278
- auto daSystemSerializationReqTy = daSerializationReqAssocType->getInterfaceType ();
1279
-
1280
- for (auto req : allRequirements) {
1281
- if (req.getSecondType ()->isAny ()) {
1282
- continue ;
1283
- }
1284
- if (!req.getFirstType ()->hasDependentMember ())
1285
- continue ;
1286
-
1287
- if (auto dependentMemberType =
1288
- req.getFirstType ()->castTo <DependentMemberType>()) {
1289
- auto dependentTy =
1290
- dependentMemberType->getAssocType ()->getInterfaceType ();
1291
-
1292
- if (dependentTy->isEqual (daSystemSerializationReqTy)) {
1293
- auto requirementProto = req.getSecondType ();
1294
- if (auto proto = dyn_cast_or_null<ProtocolDecl>(
1295
- requirementProto->getAnyNominal ())) {
1296
- serializationReqs.insert (proto);
1297
- } else {
1298
- auto serialReqType = requirementProto->castTo <ExistentialType>()
1299
- ->getConstraintType ()
1300
- ->getDesugaredType ();
1301
- auto flattenedRequirements =
1302
- flattenDistributedSerializationTypeToRequiredProtocols (
1303
- serialReqType);
1304
- for (auto p : flattenedRequirements) {
1305
- serializationReqs.insert (p);
1306
- }
1307
- }
1308
- }
1309
- }
1310
- }
1311
-
1312
- return serializationReqs;
1313
- }
1314
-
1315
1231
/* *****************************************************************************/
1316
1232
/* ********************* Distributed Functions *********************************/
1317
1233
/* *****************************************************************************/
0 commit comments