Skip to content

Commit be7198f

Browse files
authored
Merge pull request #69778 from slavapestov/random-distributed-cleanup-2
Random distributed cleanup
2 parents f5277ae + 19b3d09 commit be7198f

File tree

3 files changed

+39
-112
lines changed

3 files changed

+39
-112
lines changed

include/swift/AST/DistributedDecl.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,6 @@ llvm::SmallPtrSet<ProtocolDecl *, 2>
9191
getDistributedSerializationRequirementProtocols(
9292
NominalTypeDecl *decl, ProtocolDecl* protocol);
9393

94-
/// Desugar and flatten the `SerializationRequirement` type into a set of
95-
/// specific protocol declarations.
96-
llvm::SmallPtrSet<ProtocolDecl *, 2>
97-
flattenDistributedSerializationTypeToRequiredProtocols(
98-
TypeBase *serializationRequirement);
99-
10094
/// Check if the `allRequirements` represent *exactly* the
10195
/// `Encodable & Decodable` (also known as `Codable`) requirement.
10296
///

lib/AST/DistributedDecl.cpp

Lines changed: 25 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -340,52 +340,19 @@ swift::getDistributedSerializationRequirements(
340340
if (existentialRequirementTy->isAny())
341341
return true; // we're done here, any means there are no requirements
342342

343-
if (!existentialRequirementTy->isExistentialType()) {
344-
// SerializationRequirement must be an existential type
345-
return false;
346-
}
347-
348-
ExistentialType *serialReqType = existentialRequirementTy
349-
->castTo<ExistentialType>();
343+
auto *serialReqType = existentialRequirementTy->getAs<ExistentialType>();
350344
if (!serialReqType || serialReqType->hasError()) {
351345
return false;
352346
}
353347

354-
auto desugaredTy = serialReqType->getConstraintType()->getDesugaredType();
355-
auto flattenedRequirements =
356-
flattenDistributedSerializationTypeToRequiredProtocols(
357-
desugaredTy);
358-
for (auto p : flattenedRequirements) {
348+
auto layout = serialReqType->getExistentialLayout();
349+
for (auto p : layout.getProtocols()) {
359350
requirementProtos.insert(p);
360351
}
361352

362353
return true;
363354
}
364355

365-
llvm::SmallPtrSet<ProtocolDecl *, 2>
366-
swift::flattenDistributedSerializationTypeToRequiredProtocols(
367-
TypeBase *serializationRequirement) {
368-
llvm::SmallPtrSet<ProtocolDecl *, 2> serializationReqs;
369-
if (auto composition =
370-
serializationRequirement->getAs<ProtocolCompositionType>()) {
371-
for (auto member : composition->getMembers()) {
372-
if (auto comp = member->getAs<ProtocolCompositionType>()) {
373-
for (auto protocol :
374-
flattenDistributedSerializationTypeToRequiredProtocols(comp)) {
375-
serializationReqs.insert(protocol);
376-
}
377-
} else if (auto *protocol = member->getAs<ProtocolType>()) {
378-
serializationReqs.insert(protocol->getDecl());
379-
}
380-
}
381-
} else {
382-
auto protocol = serializationRequirement->castTo<ProtocolType>()->getDecl();
383-
serializationReqs.insert(protocol);
384-
}
385-
386-
return serializationReqs;
387-
}
388-
389356
bool swift::checkDistributedSerializationRequirementIsExactlyCodable(
390357
ASTContext &C,
391358
const llvm::SmallPtrSetImpl<ProtocolDecl *> &allRequirements) {
@@ -565,25 +532,19 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
565532

566533
// --- Check requirement: conforms_to: Act DistributedActor
567534
auto actorReq = requirements[0];
568-
auto distActorTy = C.getProtocol(KnownProtocolKind::DistributedActor)
569-
->getInterfaceType()
570-
->getMetatypeInstanceType();
571535
if (actorReq.getKind() != RequirementKind::Conformance) {
572536
return false;
573537
}
574-
if (!actorReq.getSecondType()->isEqual(distActorTy)) {
538+
if (!actorReq.getProtocolDecl()->isSpecificProtocol(KnownProtocolKind::DistributedActor)) {
575539
return false;
576540
}
577541

578542
// --- Check requirement: conforms_to: Err Error
579543
auto errorReq = requirements[1];
580-
auto errorTy = C.getProtocol(KnownProtocolKind::Error)
581-
->getInterfaceType()
582-
->getMetatypeInstanceType();
583544
if (errorReq.getKind() != RequirementKind::Conformance) {
584545
return false;
585546
}
586-
if (!errorReq.getSecondType()->isEqual(errorTy)) {
547+
if (!errorReq.getProtocolDecl()->isSpecificProtocol(KnownProtocolKind::Error)) {
587548
return false;
588549
}
589550

@@ -598,10 +559,9 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
598559
assert(ResParam && "Non void function, yet no Res generic parameter found");
599560
if (auto func = dyn_cast<FuncDecl>(this)) {
600561
auto resultType = func->mapTypeIntoContext(func->getResultInterfaceType())
601-
->getMetatypeInstanceType()
602-
->getDesugaredType();
562+
->getMetatypeInstanceType();
603563
auto resultParamType = func->mapTypeIntoContext(
604-
ResParam->getInterfaceType()->getMetatypeInstanceType());
564+
ResParam->getDeclaredInterfaceType());
605565
// The result of the function must be the `Res` generic argument.
606566
if (!resultType->isEqual(resultParamType)) {
607567
return false;
@@ -797,12 +757,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
797757

798758
// the <Value> of the RemoteCallArgument<Value>
799759
auto remoteCallArgValueGenericTy =
800-
mapTypeIntoContext(argGenericParams[0]->getInterfaceType())
801-
->getDesugaredType()
802-
->getMetatypeInstanceType();
760+
mapTypeIntoContext(argGenericParams[0]->getDeclaredInterfaceType());
803761
// expected (the <Value> from the recordArgument<Value>)
804762
auto expectedGenericParamTy = mapTypeIntoContext(
805-
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
763+
ArgumentParam->getDeclaredInterfaceType());
806764

807765
if (!remoteCallArgValueGenericTy->isEqual(expectedGenericParamTy)) {
808766
return false;
@@ -932,11 +890,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordReturnType() con
932890
// ...
933891

934892
auto resultType = func->mapTypeIntoContext(argumentParam->getInterfaceType())
935-
->getMetatypeInstanceType()
936-
->getDesugaredType();
893+
->getMetatypeInstanceType();
937894

938895
auto resultParamType = func->mapTypeIntoContext(
939-
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
896+
ArgumentParam->getDeclaredInterfaceType());
940897

941898
// The result of the function must be the `Res` generic argument.
942899
if (!resultType->isEqual(resultParamType)) {
@@ -1046,13 +1003,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordErrorType() cons
10461003

10471004
// --- Check requirement: conforms_to: Err Error
10481005
auto errorReq = requirements[0];
1049-
auto errorTy = C.getProtocol(KnownProtocolKind::Error)
1050-
->getInterfaceType()
1051-
->getMetatypeInstanceType();
10521006
if (errorReq.getKind() != RequirementKind::Conformance) {
10531007
return false;
10541008
}
1055-
if (!errorReq.getSecondType()->isEqual(errorTy)) {
1009+
if (!errorReq.getProtocolDecl()->isSpecificProtocol(KnownProtocolKind::Error)) {
10561010
return false;
10571011
}
10581012

@@ -1139,10 +1093,9 @@ AbstractFunctionDecl::isDistributedTargetInvocationDecoderDecodeNextArgument() c
11391093
// --- Check: Argument: SerializationRequirement
11401094
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
11411095
auto resultType = func->mapTypeIntoContext(func->getResultInterfaceType())
1142-
->getMetatypeInstanceType()
1143-
->getDesugaredType();
1096+
->getMetatypeInstanceType();
11441097
auto resultParamType = func->mapTypeIntoContext(
1145-
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
1098+
ArgumentParam->getDeclaredInterfaceType());
11461099
// The result of the function must be the `Res` generic argument.
11471100
if (!resultType->isEqual(resultParamType)) {
11481101
return false;
@@ -1237,11 +1190,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const
12371190
// === Check generic parameters in detail
12381191
// --- Check: Argument: SerializationRequirement
12391192
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
1240-
auto argumentType = func->mapTypeIntoContext(valueParam->getInterfaceType())
1241-
->getMetatypeInstanceType()
1242-
->getDesugaredType();
1193+
auto argumentType = func->mapTypeIntoContext(
1194+
valueParam->getInterfaceType()->getMetatypeInstanceType());
12431195
auto resultParamType = func->mapTypeIntoContext(
1244-
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
1196+
ArgumentParam->getDeclaredInterfaceType());
12451197
// The result of the function must be the `Res` generic argument.
12461198
if (!argumentType->isEqual(resultParamType)) {
12471199
return false;
@@ -1269,35 +1221,19 @@ swift::extractDistributedSerializationRequirements(
12691221
auto DA = C.getDistributedActorDecl();
12701222
auto daSerializationReqAssocType =
12711223
DA->getAssociatedType(C.Id_SerializationRequirement);
1272-
auto daSystemSerializationReqTy = daSerializationReqAssocType->getInterfaceType();
12731224

12741225
for (auto req : allRequirements) {
1275-
if (req.getSecondType()->isAny()) {
1276-
continue;
1277-
}
1278-
if (!req.getFirstType()->hasDependentMember())
1226+
// FIXME: Seems unprincipled
1227+
if (req.getKind() != RequirementKind::SameType &&
1228+
req.getKind() != RequirementKind::Conformance)
12791229
continue;
12801230

12811231
if (auto dependentMemberType =
1282-
req.getFirstType()->castTo<DependentMemberType>()) {
1283-
auto dependentTy =
1284-
dependentMemberType->getAssocType()->getInterfaceType();
1285-
1286-
if (dependentTy->isEqual(daSystemSerializationReqTy)) {
1287-
auto requirementProto = req.getSecondType();
1288-
if (auto proto = dyn_cast_or_null<ProtocolDecl>(
1289-
requirementProto->getAnyNominal())) {
1290-
serializationReqs.insert(proto);
1291-
} else {
1292-
auto serialReqType = requirementProto->castTo<ExistentialType>()
1293-
->getConstraintType()
1294-
->getDesugaredType();
1295-
auto flattenedRequirements =
1296-
flattenDistributedSerializationTypeToRequiredProtocols(
1297-
serialReqType);
1298-
for (auto p : flattenedRequirements) {
1299-
serializationReqs.insert(p);
1300-
}
1232+
req.getFirstType()->getAs<DependentMemberType>()) {
1233+
if (dependentMemberType->getAssocType() == daSerializationReqAssocType) {
1234+
auto layout = req.getSecondType()->getExistentialLayout();
1235+
for (auto p : layout.getProtocols()) {
1236+
serializationReqs.insert(p);
13011237
}
13021238
}
13031239
}

lib/Sema/TypeCheckDistributed.cpp

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -507,8 +507,8 @@ bool CheckDistributedFunctionRequest::evaluate(
507507
} else if (isa<ProtocolDecl>(DC)) {
508508
if (auto seqReqTy =
509509
getConcreteReplacementForMemberSerializationRequirement(func)) {
510-
auto seqReqTyDes = seqReqTy->castTo<ExistentialType>()->getConstraintType()->getDesugaredType();
511-
for (auto req : flattenDistributedSerializationTypeToRequiredProtocols(seqReqTyDes)) {
510+
auto layout = seqReqTy->getExistentialLayout();
511+
for (auto req : layout.getProtocols()) {
512512
serializationRequirements.insert(req);
513513
}
514514
}
@@ -759,11 +759,13 @@ swift::getDistributedSerializationRequirementProtocols(
759759
return {};
760760
}
761761

762-
auto serialReqType =
763-
ty->castTo<ExistentialType>()->getConstraintType()->getDesugaredType();
764-
765762
// TODO(distributed): check what happens with Any
766-
return flattenDistributedSerializationTypeToRequiredProtocols(serialReqType);
763+
auto layout = ty->getExistentialLayout();
764+
llvm::SmallPtrSet<ProtocolDecl *, 2> result;
765+
for (auto p : layout.getProtocols()) {
766+
result.insert(p);
767+
}
768+
return result;
767769
}
768770

769771
ConstructorDecl*
@@ -887,29 +889,24 @@ GetDistributedActorArgumentDecodingMethodRequest::evaluate(Evaluator &evaluator,
887889
continue;
888890

889891
auto paramTy = genericParamList->getParams()[0]
890-
->getInterfaceType()
891-
->getMetatypeInstanceType();
892+
->getDeclaredInterfaceType();
892893

893894
// `decodeNextArgument` should return its generic parameter value
894895
if (!FD->getResultInterfaceType()->isEqual(paramTy))
895896
continue;
896897

897898
// Let's find out how many serialization requirements does this method cover
898899
// e.g. `Codable` is two requirements - `Encodable` and `Decodable`.
899-
unsigned numSerializationReqsCovered = llvm::count_if(
900-
FD->getGenericRequirements(), [&](const Requirement &requirement) {
901-
if (!(requirement.getFirstType()->isEqual(paramTy) &&
902-
requirement.getKind() == RequirementKind::Conformance))
903-
return 0;
904-
905-
return serializationReqs.count(requirement.getProtocolDecl()) ? 1 : 0;
906-
});
900+
bool okay = llvm::all_of(serializationReqs,
901+
[&](ProtocolDecl *p) -> bool {
902+
return FD->getGenericSignature()->requiresProtocol(paramTy, p);
903+
});
907904

908905
// If the current method covers all of the serialization requirements,
909906
// it's a match. Note that it might also have other requirements, but
910907
// we let that go as long as there are no two candidates that differ
911908
// only in generic requirements.
912-
if (numSerializationReqsCovered == serializationReqs.size())
909+
if (okay)
913910
candidates.push_back(FD);
914911
}
915912

0 commit comments

Comments
 (0)