Skip to content

Commit 5a44f7f

Browse files
committed
[Distributed] Sema: Adjust decoding method lookup to check for covered serialization requirements
A proper `decodeNextArgument` candidate should cover all of the protocols listed in `SerializationRequirement` associated with distributed actor it would be used for.
1 parent 2fe4417 commit 5a44f7f

File tree

2 files changed

+49
-8
lines changed

2 files changed

+49
-8
lines changed

include/swift/AST/KnownIdentifiers.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ IDENTIFIER(using)
292292
IDENTIFIER(InvocationDecoder)
293293
IDENTIFIER(whenLocal)
294294
IDENTIFIER(decodeNextArgument)
295+
IDENTIFIER(SerializationRequirement)
295296

296297
#undef IDENTIFIER
297298
#undef IDENTIFIER_

lib/Sema/TypeCheckDistributed.cpp

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -472,26 +472,66 @@ GetDistributedActorArgumentDecodingMethodRequest::evaluate(Evaluator &evaluator,
472472
auto members = TypeChecker::lookupMember(actor->getDeclContext(), decoderTy,
473473
DeclNameRef(ctx.Id_decodeNextArgument));
474474

475+
// typealias SerializationRequirement = any ...
476+
auto serializerType = getAssociatedTypeOfDistributedSystem(
477+
actor, ctx.Id_SerializationRequirement)
478+
->castTo<ExistentialType>()
479+
->getConstraintType()
480+
->getDesugaredType();
481+
482+
llvm::SmallPtrSet<ProtocolDecl *, 2> serializationReqs;
483+
if (auto composition = serializerType->getAs<ProtocolCompositionType>()) {
484+
for (auto member : composition->getMembers()) {
485+
if (auto *protocol = member->getAs<ProtocolType>())
486+
serializationReqs.insert(protocol->getDecl());
487+
}
488+
} else {
489+
auto protocol = serializerType->castTo<ProtocolType>()->getDecl();
490+
serializationReqs.insert(protocol);
491+
}
492+
475493
SmallVector<FuncDecl *, 2> candidates;
476-
// Looking for `decodeNextArgument<Arg>() throws -> Arg`
494+
// Looking for `decodeNextArgument<Arg: <SerializationReq>>() throws -> Arg`
477495
for (auto &member : members) {
478496
auto *FD = dyn_cast<FuncDecl>(member.getValueDecl());
479497
if (!FD || FD->hasAsync() || !FD->hasThrows())
480498
continue;
481499

482500
auto *params = FD->getParameters();
501+
// No arguemnts.
483502
if (params->size() != 0)
484503
continue;
485504

486505
auto genericParamList = FD->getGenericParams();
487-
if (genericParamList->size() == 1) {
488-
auto paramTy = genericParamList->getParams()[0]
489-
->getInterfaceType()
490-
->getMetatypeInstanceType();
506+
// A single generic parameter.
507+
if (genericParamList->size() != 1)
508+
continue;
491509

492-
if (FD->getResultInterfaceType()->isEqual(paramTy))
493-
candidates.push_back(FD);
494-
}
510+
auto paramTy = genericParamList->getParams()[0]
511+
->getInterfaceType()
512+
->getMetatypeInstanceType();
513+
514+
// `decodeNextArgument` should return its generic parameter value
515+
if (!FD->getResultInterfaceType()->isEqual(paramTy))
516+
continue;
517+
518+
// Let's find out how many serialization requirements does this method cover
519+
// e.g. `Codable` is two requirements - `Encodable` and `Decodable`.
520+
unsigned numSerializationReqsCovered = llvm::count_if(
521+
FD->getGenericRequirements(), [&](const Requirement &requirement) {
522+
if (!(requirement.getFirstType()->isEqual(paramTy) &&
523+
requirement.getKind() == RequirementKind::Conformance))
524+
return 0;
525+
526+
return serializationReqs.count(requirement.getProtocolDecl()) ? 1 : 0;
527+
});
528+
529+
// If the current method covers all of the serialization requirements,
530+
// it's a match. Note that it might also have other requirements, but
531+
// we let that go as long as there are no two candidates that differ
532+
// only in generic requirements.
533+
if (numSerializationReqsCovered == serializationReqs.size())
534+
candidates.push_back(FD);
495535
}
496536

497537
// Type-checker should reject any definition of invocation decoder

0 commit comments

Comments
 (0)