@@ -340,52 +340,19 @@ swift::getDistributedSerializationRequirements(
340
340
if (existentialRequirementTy->isAny ())
341
341
return true ; // we're done here, any means there are no requirements
342
342
343
- if (!existentialRequirementTy->isExistentialType ()) {
344
- // SerializationRequirement must be an existential type
345
- return false ;
346
- }
347
-
348
- ExistentialType *serialReqType = existentialRequirementTy
349
- ->castTo <ExistentialType>();
343
+ auto *serialReqType = existentialRequirementTy->getAs <ExistentialType>();
350
344
if (!serialReqType || serialReqType->hasError ()) {
351
345
return false ;
352
346
}
353
347
354
- auto desugaredTy = serialReqType->getConstraintType ()->getDesugaredType ();
355
- auto flattenedRequirements =
356
- flattenDistributedSerializationTypeToRequiredProtocols (
357
- desugaredTy);
358
- for (auto p : flattenedRequirements) {
348
+ auto layout = serialReqType->getExistentialLayout ();
349
+ for (auto p : layout.getProtocols ()) {
359
350
requirementProtos.insert (p);
360
351
}
361
352
362
353
return true ;
363
354
}
364
355
365
- llvm::SmallPtrSet<ProtocolDecl *, 2 >
366
- swift::flattenDistributedSerializationTypeToRequiredProtocols (
367
- TypeBase *serializationRequirement) {
368
- llvm::SmallPtrSet<ProtocolDecl *, 2 > serializationReqs;
369
- if (auto composition =
370
- serializationRequirement->getAs <ProtocolCompositionType>()) {
371
- for (auto member : composition->getMembers ()) {
372
- if (auto comp = member->getAs <ProtocolCompositionType>()) {
373
- for (auto protocol :
374
- flattenDistributedSerializationTypeToRequiredProtocols (comp)) {
375
- serializationReqs.insert (protocol);
376
- }
377
- } else if (auto *protocol = member->getAs <ProtocolType>()) {
378
- serializationReqs.insert (protocol->getDecl ());
379
- }
380
- }
381
- } else {
382
- auto protocol = serializationRequirement->castTo <ProtocolType>()->getDecl ();
383
- serializationReqs.insert (protocol);
384
- }
385
-
386
- return serializationReqs;
387
- }
388
-
389
356
bool swift::checkDistributedSerializationRequirementIsExactlyCodable (
390
357
ASTContext &C,
391
358
const llvm::SmallPtrSetImpl<ProtocolDecl *> &allRequirements) {
@@ -565,25 +532,19 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
565
532
566
533
// --- Check requirement: conforms_to: Act DistributedActor
567
534
auto actorReq = requirements[0 ];
568
- auto distActorTy = C.getProtocol (KnownProtocolKind::DistributedActor)
569
- ->getInterfaceType ()
570
- ->getMetatypeInstanceType ();
571
535
if (actorReq.getKind () != RequirementKind::Conformance) {
572
536
return false ;
573
537
}
574
- if (!actorReq.getSecondType ()->isEqual (distActorTy )) {
538
+ if (!actorReq.getProtocolDecl ()->isSpecificProtocol (KnownProtocolKind::DistributedActor )) {
575
539
return false ;
576
540
}
577
541
578
542
// --- Check requirement: conforms_to: Err Error
579
543
auto errorReq = requirements[1 ];
580
- auto errorTy = C.getProtocol (KnownProtocolKind::Error)
581
- ->getInterfaceType ()
582
- ->getMetatypeInstanceType ();
583
544
if (errorReq.getKind () != RequirementKind::Conformance) {
584
545
return false ;
585
546
}
586
- if (!errorReq.getSecondType ()->isEqual (errorTy )) {
547
+ if (!errorReq.getProtocolDecl ()->isSpecificProtocol (KnownProtocolKind::Error )) {
587
548
return false ;
588
549
}
589
550
@@ -598,10 +559,9 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
598
559
assert (ResParam && " Non void function, yet no Res generic parameter found" );
599
560
if (auto func = dyn_cast<FuncDecl>(this )) {
600
561
auto resultType = func->mapTypeIntoContext (func->getResultInterfaceType ())
601
- ->getMetatypeInstanceType ()
602
- ->getDesugaredType ();
562
+ ->getMetatypeInstanceType ();
603
563
auto resultParamType = func->mapTypeIntoContext (
604
- ResParam->getInterfaceType ()-> getMetatypeInstanceType ());
564
+ ResParam->getDeclaredInterfaceType ());
605
565
// The result of the function must be the `Res` generic argument.
606
566
if (!resultType->isEqual (resultParamType)) {
607
567
return false ;
@@ -797,12 +757,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
797
757
798
758
// the <Value> of the RemoteCallArgument<Value>
799
759
auto remoteCallArgValueGenericTy =
800
- mapTypeIntoContext (argGenericParams[0 ]->getInterfaceType ())
801
- ->getDesugaredType ()
802
- ->getMetatypeInstanceType ();
760
+ mapTypeIntoContext (argGenericParams[0 ]->getDeclaredInterfaceType ());
803
761
// expected (the <Value> from the recordArgument<Value>)
804
762
auto expectedGenericParamTy = mapTypeIntoContext (
805
- ArgumentParam->getInterfaceType ()-> getMetatypeInstanceType ());
763
+ ArgumentParam->getDeclaredInterfaceType ());
806
764
807
765
if (!remoteCallArgValueGenericTy->isEqual (expectedGenericParamTy)) {
808
766
return false ;
@@ -932,11 +890,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordReturnType() con
932
890
// ...
933
891
934
892
auto resultType = func->mapTypeIntoContext (argumentParam->getInterfaceType ())
935
- ->getMetatypeInstanceType ()
936
- ->getDesugaredType ();
893
+ ->getMetatypeInstanceType ();
937
894
938
895
auto resultParamType = func->mapTypeIntoContext (
939
- ArgumentParam->getInterfaceType ()-> getMetatypeInstanceType ());
896
+ ArgumentParam->getDeclaredInterfaceType ());
940
897
941
898
// The result of the function must be the `Res` generic argument.
942
899
if (!resultType->isEqual (resultParamType)) {
@@ -1046,13 +1003,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordErrorType() cons
1046
1003
1047
1004
// --- Check requirement: conforms_to: Err Error
1048
1005
auto errorReq = requirements[0 ];
1049
- auto errorTy = C.getProtocol (KnownProtocolKind::Error)
1050
- ->getInterfaceType ()
1051
- ->getMetatypeInstanceType ();
1052
1006
if (errorReq.getKind () != RequirementKind::Conformance) {
1053
1007
return false ;
1054
1008
}
1055
- if (!errorReq.getSecondType ()->isEqual (errorTy )) {
1009
+ if (!errorReq.getProtocolDecl ()->isSpecificProtocol (KnownProtocolKind::Error )) {
1056
1010
return false ;
1057
1011
}
1058
1012
@@ -1139,10 +1093,9 @@ AbstractFunctionDecl::isDistributedTargetInvocationDecoderDecodeNextArgument() c
1139
1093
// --- Check: Argument: SerializationRequirement
1140
1094
GenericTypeParamDecl *ArgumentParam = genericParams->getParams ()[0 ];
1141
1095
auto resultType = func->mapTypeIntoContext (func->getResultInterfaceType ())
1142
- ->getMetatypeInstanceType ()
1143
- ->getDesugaredType ();
1096
+ ->getMetatypeInstanceType ();
1144
1097
auto resultParamType = func->mapTypeIntoContext (
1145
- ArgumentParam->getInterfaceType ()-> getMetatypeInstanceType ());
1098
+ ArgumentParam->getDeclaredInterfaceType ());
1146
1099
// The result of the function must be the `Res` generic argument.
1147
1100
if (!resultType->isEqual (resultParamType)) {
1148
1101
return false ;
@@ -1237,11 +1190,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const
1237
1190
// === Check generic parameters in detail
1238
1191
// --- Check: Argument: SerializationRequirement
1239
1192
GenericTypeParamDecl *ArgumentParam = genericParams->getParams ()[0 ];
1240
- auto argumentType = func->mapTypeIntoContext (valueParam->getInterfaceType ())
1241
- ->getMetatypeInstanceType ()
1242
- ->getDesugaredType ();
1193
+ auto argumentType = func->mapTypeIntoContext (
1194
+ valueParam->getInterfaceType ()->getMetatypeInstanceType ());
1243
1195
auto resultParamType = func->mapTypeIntoContext (
1244
- ArgumentParam->getInterfaceType ()-> getMetatypeInstanceType ());
1196
+ ArgumentParam->getDeclaredInterfaceType ());
1245
1197
// The result of the function must be the `Res` generic argument.
1246
1198
if (!argumentType->isEqual (resultParamType)) {
1247
1199
return false ;
@@ -1269,35 +1221,19 @@ swift::extractDistributedSerializationRequirements(
1269
1221
auto DA = C.getDistributedActorDecl ();
1270
1222
auto daSerializationReqAssocType =
1271
1223
DA->getAssociatedType (C.Id_SerializationRequirement );
1272
- auto daSystemSerializationReqTy = daSerializationReqAssocType->getInterfaceType ();
1273
1224
1274
1225
for (auto req : allRequirements) {
1275
- if (req.getSecondType ()->isAny ()) {
1276
- continue ;
1277
- }
1278
- if (!req.getFirstType ()->hasDependentMember ())
1226
+ // FIXME: Seems unprincipled
1227
+ if (req.getKind () != RequirementKind::SameType &&
1228
+ req.getKind () != RequirementKind::Conformance)
1279
1229
continue ;
1280
1230
1281
1231
if (auto dependentMemberType =
1282
- req.getFirstType ()->castTo <DependentMemberType>()) {
1283
- auto dependentTy =
1284
- dependentMemberType->getAssocType ()->getInterfaceType ();
1285
-
1286
- if (dependentTy->isEqual (daSystemSerializationReqTy)) {
1287
- auto requirementProto = req.getSecondType ();
1288
- if (auto proto = dyn_cast_or_null<ProtocolDecl>(
1289
- requirementProto->getAnyNominal ())) {
1290
- serializationReqs.insert (proto);
1291
- } else {
1292
- auto serialReqType = requirementProto->castTo <ExistentialType>()
1293
- ->getConstraintType ()
1294
- ->getDesugaredType ();
1295
- auto flattenedRequirements =
1296
- flattenDistributedSerializationTypeToRequiredProtocols (
1297
- serialReqType);
1298
- for (auto p : flattenedRequirements) {
1299
- serializationReqs.insert (p);
1300
- }
1232
+ req.getFirstType ()->getAs <DependentMemberType>()) {
1233
+ if (dependentMemberType->getAssocType () == daSerializationReqAssocType) {
1234
+ auto layout = req.getSecondType ()->getExistentialLayout ();
1235
+ for (auto p : layout.getProtocols ()) {
1236
+ serializationReqs.insert (p);
1301
1237
}
1302
1238
}
1303
1239
}
0 commit comments