Skip to content

Commit 9d56965

Browse files
committed
Use more of getConcreteReplacementForMemberSerializationRequirement
1 parent 14b5d5b commit 9d56965

File tree

3 files changed

+94
-114
lines changed

3 files changed

+94
-114
lines changed

include/swift/AST/DistributedDecl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ getDistributedSerializationRequirementProtocols(
9797
/// If so, we can emit slightly nicer diagnostics.
9898
bool checkDistributedSerializationRequirementIsExactlyCodable(
9999
ASTContext &C,
100-
const llvm::SmallPtrSetImpl<ProtocolDecl *> &allRequirements);
100+
Type type);
101101

102102
/// Get the `SerializationRequirement`, explode it into the specific
103103
/// protocol requirements and insert them into `requirements`.

lib/AST/DistributedDecl.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,18 +106,17 @@ Type swift::getConcreteReplacementForMemberSerializationRequirement(
106106
return getDistributedSerializationRequirementType(classDecl, C.getDistributedActorDecl());
107107
}
108108

109-
/// === Maybe the value is declared in a protocol?
110-
if (auto protocol = DC->getSelfProtocolDecl()) {
109+
auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement)
110+
->getDeclaredInterfaceType();
111+
112+
if (DC->getSelfProtocolDecl() || isa<ExtensionDecl>(DC)) {
111113
GenericSignature signature;
112114
if (auto *genericContext = member->getAsGenericContext()) {
113115
signature = genericContext->getGenericSignature();
114116
} else {
115117
signature = DC->getGenericSignatureOfContext();
116118
}
117119

118-
auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement)
119-
->getDeclaredInterfaceType();
120-
121120
// Note that this may be null, e.g. if we're a distributed func inside
122121
// a protocol that did not declare a specific actor system requirement.
123122
return signature->getConcreteType(SerReqAssocType);
@@ -355,15 +354,24 @@ swift::getDistributedSerializationRequirements(
355354

356355
bool swift::checkDistributedSerializationRequirementIsExactlyCodable(
357356
ASTContext &C,
358-
const llvm::SmallPtrSetImpl<ProtocolDecl *> &allRequirements) {
357+
Type type) {
358+
if (!type)
359+
return false;
360+
361+
if (type->hasError())
362+
return false;
363+
359364
auto encodable = C.getProtocol(KnownProtocolKind::Encodable);
360365
auto decodable = C.getProtocol(KnownProtocolKind::Decodable);
361366

362-
if (allRequirements.size() != 2)
367+
auto layout = type->getExistentialLayout();
368+
auto protocols = layout.getProtocols();
369+
370+
if (protocols.size() != 2)
363371
return false;
364372

365-
return allRequirements.count(encodable) &&
366-
allRequirements.count(decodable);
373+
return std::count(protocols.begin(), protocols.end(), encodable) == 1 &&
374+
std::count(protocols.begin(), protocols.end(), decodable) == 1;
367375
}
368376

369377
/******************************************************************************/

lib/Sema/TypeCheckDistributed.cpp

Lines changed: 76 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -385,10 +385,13 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
385385

386386
static bool checkDistributedTargetResultType(
387387
ModuleDecl *module, ValueDecl *valueDecl,
388-
const llvm::SmallPtrSetImpl<ProtocolDecl *> &serializationRequirements,
388+
Type serializationRequirement,
389389
bool diagnose) {
390390
auto &C = valueDecl->getASTContext();
391391

392+
if (!serializationRequirement || serializationRequirement->hasError())
393+
return false; // error of the type would be diagnosed elsewhere
394+
392395
Type resultType;
393396
if (auto func = dyn_cast<FuncDecl>(valueDecl)) {
394397
resultType = func->mapTypeIntoContext(func->getResultInterfaceType());
@@ -403,36 +406,39 @@ static bool checkDistributedTargetResultType(
403406

404407
auto isCodableRequirement =
405408
checkDistributedSerializationRequirementIsExactlyCodable(
406-
C, serializationRequirements);
407-
408-
for(auto serializationReq : serializationRequirements) {
409-
auto conformance =
410-
TypeChecker::conformsToProtocol(resultType, serializationReq, module);
411-
if (conformance.isInvalid()) {
412-
if (diagnose) {
413-
llvm::StringRef conformanceToSuggest = isCodableRequirement ?
414-
"Codable" : // Codable is a typealias, easier to diagnose like that
415-
serializationReq->getNameStr();
416-
417-
auto diag = valueDecl->diagnose(
418-
diag::distributed_actor_target_result_not_codable,
419-
resultType,
420-
valueDecl,
421-
conformanceToSuggest
422-
);
423-
424-
if (isCodableRequirement) {
425-
if (auto resultNominalType = resultType->getAnyNominal()) {
426-
addCodableFixIt(resultNominalType, diag);
409+
C, serializationRequirement);
410+
411+
if (serializationRequirement && !serializationRequirement->hasError()) {
412+
auto srl = serializationRequirement->getExistentialLayout();
413+
for (auto serializationReq: srl.getProtocols()) {
414+
auto conformance =
415+
TypeChecker::conformsToProtocol(resultType, serializationReq, module);
416+
if (conformance.isInvalid()) {
417+
if (diagnose) {
418+
llvm::StringRef conformanceToSuggest = isCodableRequirement ?
419+
"Codable" : // Codable is a typealias, easier to diagnose like that
420+
serializationReq->getNameStr();
421+
422+
auto diag = valueDecl->diagnose(
423+
diag::distributed_actor_target_result_not_codable,
424+
resultType,
425+
valueDecl,
426+
conformanceToSuggest
427+
);
428+
429+
if (isCodableRequirement) {
430+
if (auto resultNominalType = resultType->getAnyNominal()) {
431+
addCodableFixIt(resultNominalType, diag);
432+
}
427433
}
428-
}
429-
} // end if: diagnose
430-
431-
return true;
434+
} // end if: diagnose
435+
436+
return true;
437+
}
432438
}
433439
}
434440

435-
return false;
441+
return false;
436442
}
437443

438444
bool swift::checkDistributedActorSystem(const NominalTypeDecl *system) {
@@ -503,74 +509,35 @@ bool CheckDistributedFunctionRequest::evaluate(
503509
if (!C.getLoadedModule(C.Id_Distributed))
504510
return true;
505511

506-
// === All parameters and the result type must conform
507-
// SerializationRequirement
508-
llvm::SmallPtrSet<ProtocolDecl *, 2> serializationRequirements;
509-
if (auto extension = dyn_cast<ExtensionDecl>(DC)) {
510-
auto actorOrProtocol = extension->getExtendedNominal();
511-
if (auto actor = dyn_cast<ClassDecl>(actorOrProtocol)) {
512-
assert(actor->isAnyActor());
513-
serializationRequirements = getDistributedSerializationRequirementProtocols(
514-
getDistributedActorSystemType(actor)->getAnyNominal(),
515-
C.getProtocol(KnownProtocolKind::DistributedActorSystem));
516-
} else if (auto protocol = dyn_cast<ProtocolDecl>(actorOrProtocol)) {
517-
extractDistributedSerializationRequirements(
518-
C, protocol->getGenericRequirements(),
519-
/*into=*/serializationRequirements);
520-
extractDistributedSerializationRequirements(
521-
C, extension->getGenericRequirements(),
522-
/*into=*/serializationRequirements);
523-
} else {
524-
// ignore
525-
}
526-
} else if (auto actor = dyn_cast<ClassDecl>(DC)) {
527-
serializationRequirements = getDistributedSerializationRequirementProtocols(
528-
getDistributedActorSystemType(actor)->getAnyNominal(),
529-
C.getProtocol(KnownProtocolKind::DistributedActorSystem));
530-
} else if (isa<ProtocolDecl>(DC)) {
531-
if (auto seqReqTy =
532-
getConcreteReplacementForMemberSerializationRequirement(func)) {
533-
auto layout = seqReqTy->getExistentialLayout();
534-
for (auto req : layout.getProtocols()) {
535-
serializationRequirements.insert(req);
536-
}
537-
}
538-
539-
// The distributed actor constrained protocol has no serialization requirements
540-
// or actor system defined, so these will only be enforced, by implementations
541-
// of DAs conforming to it, skip checks here.
542-
if (serializationRequirements.empty()) {
543-
return false;
544-
}
545-
} else {
546-
llvm_unreachable("Distributed function detected in type other than extension, "
547-
"distributed actor, or protocol! This should not be possible "
548-
", please file a bug.");
549-
}
550-
551-
// If the requirement is exactly `Codable` we diagnose it ia bit nicer.
552-
auto serializationRequirementIsCodable =
553-
checkDistributedSerializationRequirementIsExactlyCodable(
554-
C, serializationRequirements);
555-
556-
for (auto param : *func->getParameters()) {
557-
// --- Check parameters for 'Codable' conformance
558-
auto paramTy = func->mapTypeIntoContext(param->getInterfaceType());
559-
560-
for (auto req : serializationRequirements) {
561-
if (TypeChecker::conformsToProtocol(paramTy, req, module).isInvalid()) {
562-
auto diag = func->diagnose(
563-
diag::distributed_actor_func_param_not_codable,
564-
param->getArgumentName().str(), param->getInterfaceType(),
565-
func->getDescriptiveKind(),
566-
serializationRequirementIsCodable ? "Codable"
567-
: req->getNameStr());
568-
569-
if (auto paramNominalTy = paramTy->getAnyNominal()) {
570-
addCodableFixIt(paramNominalTy, diag);
571-
} // else, no nominal type to suggest the fixit for, e.g. a closure
572-
573-
return true;
512+
Type serializationReqType = getConcreteReplacementForMemberSerializationRequirement(func);
513+
for (auto param: *func->getParameters()) {
514+
515+
// --- Check the parameter conforming to serialization requirements
516+
if (serializationReqType && !serializationReqType->hasError()) {
517+
// If the requirement is exactly `Codable` we diagnose it ia bit nicer.
518+
auto serializationRequirementIsCodable =
519+
checkDistributedSerializationRequirementIsExactlyCodable(
520+
C, serializationReqType);
521+
522+
// --- Check parameters for 'SerializationRequirement' conformance
523+
auto paramTy = func->mapTypeIntoContext(param->getInterfaceType());
524+
525+
auto srl = serializationReqType->getExistentialLayout();
526+
for (auto req: srl.getProtocols()) {
527+
if (TypeChecker::conformsToProtocol(paramTy, req, module).isInvalid()) {
528+
auto diag = func->diagnose(
529+
diag::distributed_actor_func_param_not_codable,
530+
param->getArgumentName().str(), param->getInterfaceType(),
531+
func->getDescriptiveKind(),
532+
serializationRequirementIsCodable ? "Codable"
533+
: req->getNameStr());
534+
535+
if (auto paramNominalTy = paramTy->getAnyNominal()) {
536+
addCodableFixIt(paramNominalTy, diag);
537+
} // else, no nominal type to suggest the fixit for, e.g. a closure
538+
539+
return true;
540+
}
574541
}
575542
}
576543

@@ -607,10 +574,12 @@ bool CheckDistributedFunctionRequest::evaluate(
607574
}
608575
}
609576

610-
// --- Result type must be either void or a codable type
611-
if (checkDistributedTargetResultType(module, func, serializationRequirements,
612-
/*diagnose=*/true)) {
613-
return true;
577+
if (serializationReqType && !serializationReqType->hasError()) {
578+
// --- Result type must be either void or a codable type
579+
if (checkDistributedTargetResultType(module, func, serializationReqType,
580+
/*diagnose=*/true)) {
581+
return true;
582+
}
614583
}
615584

616585
return false;
@@ -658,13 +627,15 @@ bool swift::checkDistributedActorProperty(VarDecl *var, bool diagnose) {
658627
DC->getSelfNominalTypeDecl()->getDistributedActorSystemProperty();
659628
auto systemDecl = systemVar->getInterfaceType()->getAnyNominal();
660629

661-
auto serializationRequirements =
662-
getDistributedSerializationRequirementProtocols(
663-
systemDecl,
664-
C.getProtocol(KnownProtocolKind::DistributedActorSystem));
630+
// auto serializationRequirements =
631+
// getDistributedSerializationRequirementProtocols(
632+
// systemDecl,
633+
// C.getProtocol(KnownProtocolKind::DistributedActorSystem));
634+
auto serializationRequirement =
635+
getConcreteReplacementForMemberSerializationRequirement(systemVar);
665636

666637
auto module = var->getModuleContext();
667-
if (checkDistributedTargetResultType(module, var, serializationRequirements, diagnose)) {
638+
if (checkDistributedTargetResultType(module, var, serializationRequirement, diagnose)) {
668639
return true;
669640
}
670641

@@ -771,6 +742,7 @@ bool TypeChecker::checkDistributedFunc(FuncDecl *func) {
771742
return swift::checkDistributedFunction(func);
772743
}
773744

745+
// TODO(distributed): Remove this entirely and rely on generic signature and getConcrete to implement checks
774746
llvm::SmallPtrSet<ProtocolDecl *, 2>
775747
swift::getDistributedSerializationRequirementProtocols(
776748
NominalTypeDecl *nominal, ProtocolDecl *protocol) {

0 commit comments

Comments
 (0)