Skip to content

Commit 3d47f70

Browse files
committed
handle conformance requirement on extension in distributed funcs
1 parent 9d56965 commit 3d47f70

File tree

4 files changed

+70
-83
lines changed

4 files changed

+70
-83
lines changed

include/swift/AST/DistributedDecl.h

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ Type getDistributedActorIDType(NominalTypeDecl *actor);
5050
/// Similar to `getDistributedSerializationRequirementType`, however, from the
5151
/// perspective of a concrete function. This way we're able to get the
5252
/// serialization requirement for specific members, also in protocols.
53-
Type getConcreteReplacementForMemberSerializationRequirement(ValueDecl *member);
53+
Type getSerializationRequirementTypesForMember(
54+
ValueDecl *member, llvm::SmallPtrSet<ProtocolDecl *, 2> &serializationRequirements);
5455

5556
/// Get specific 'SerializationRequirement' as defined in 'nominal'
5657
/// type, which must conform to the passed 'protocol' which is expected
@@ -114,17 +115,6 @@ getDistributedSerializationRequirements(
114115
ProtocolDecl *protocol,
115116
llvm::SmallPtrSetImpl<ProtocolDecl *> &requirementProtos);
116117

117-
/// Given any set of generic requirements, locate those which are about the
118-
/// `SerializationRequirement`. Those need to be applied in the parameter and
119-
/// return type checking of distributed targets.
120-
void
121-
extractDistributedSerializationRequirements(
122-
ASTContext &C,
123-
ArrayRef<Requirement> allRequirements,
124-
llvm::SmallPtrSet<ProtocolDecl *, 2> &into);
125-
126-
}
127-
128118
// ==== ------------------------------------------------------------------------
129119

130120
#endif /* SWIFT_DECL_DISTRIBUTEDDECL_H */

lib/AST/DistributedDecl.cpp

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,9 @@ Type swift::getConcreteReplacementForProtocolActorSystemType(ValueDecl *member)
9595
llvm_unreachable("Unable to fetch ActorSystem type!");
9696
}
9797

98-
Type swift::getConcreteReplacementForMemberSerializationRequirement(
99-
ValueDecl *member) {
98+
Type swift::getSerializationRequirementTypesForMember(
99+
ValueDecl *member,
100+
llvm::SmallPtrSet<ProtocolDecl *, 2> &serializationRequirements) {
100101
auto &C = member->getASTContext();
101102
auto *DC = member->getDeclContext();
102103
auto DA = C.getDistributedActorDecl();
@@ -117,6 +118,18 @@ Type swift::getConcreteReplacementForMemberSerializationRequirement(
117118
signature = DC->getGenericSignatureOfContext();
118119
}
119120

121+
// Also store all `SerializationRequirement : SomeProtocol` requirements
122+
for (auto requirement: signature.getRequirements()) {
123+
if (requirement.getFirstType()->isEqual(SerReqAssocType) &&
124+
requirement.getKind() == RequirementKind::Conformance) {
125+
if (auto nominal = requirement.getSecondType()->getAnyNominal()) {
126+
if (auto protocol = dyn_cast<ProtocolDecl>(nominal)) {
127+
serializationRequirements.insert(protocol);
128+
}
129+
}
130+
}
131+
}
132+
120133
// Note that this may be null, e.g. if we're a distributed func inside
121134
// a protocol that did not declare a specific actor system requirement.
122135
return signature->getConcreteType(SerReqAssocType);
@@ -1222,33 +1235,6 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const
12221235
return true;
12231236
}
12241237

1225-
void
1226-
swift::extractDistributedSerializationRequirements(
1227-
ASTContext &C,
1228-
ArrayRef<Requirement> allRequirements,
1229-
llvm::SmallPtrSet<ProtocolDecl *, 2> &into) {
1230-
auto DA = C.getDistributedActorDecl();
1231-
auto daSerializationReqAssocType =
1232-
DA->getAssociatedType(C.Id_SerializationRequirement);
1233-
1234-
for (auto req : allRequirements) {
1235-
// FIXME: Seems unprincipled
1236-
if (req.getKind() != RequirementKind::SameType &&
1237-
req.getKind() != RequirementKind::Conformance)
1238-
continue;
1239-
1240-
if (auto dependentMemberType =
1241-
req.getFirstType()->getAs<DependentMemberType>()) {
1242-
if (dependentMemberType->getAssocType() == daSerializationReqAssocType) {
1243-
auto layout = req.getSecondType()->getExistentialLayout();
1244-
for (auto p : layout.getProtocols()) {
1245-
serializationReqs.insert(p);
1246-
}
1247-
}
1248-
}
1249-
}
1250-
}
1251-
12521238
/******************************************************************************/
12531239
/********************** Distributed Functions *********************************/
12541240
/******************************************************************************/

lib/Sema/TypeCheckDistributed.cpp

Lines changed: 52 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -386,11 +386,16 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
386386
static bool checkDistributedTargetResultType(
387387
ModuleDecl *module, ValueDecl *valueDecl,
388388
Type serializationRequirement,
389+
llvm::SmallPtrSet<ProtocolDecl *, 2> serializationRequirements,
389390
bool diagnose) {
390391
auto &C = valueDecl->getASTContext();
391392

392-
if (!serializationRequirement || serializationRequirement->hasError())
393+
if (serializationRequirement && serializationRequirement->hasError()) {
394+
return false;
395+
}
396+
if ((!serializationRequirement || serializationRequirement->hasError()) && serializationRequirements.empty()) {
393397
return false; // error of the type would be diagnosed elsewhere
398+
}
394399

395400
Type resultType;
396401
if (auto func = dyn_cast<FuncDecl>(valueDecl)) {
@@ -404,37 +409,43 @@ static bool checkDistributedTargetResultType(
404409
if (resultType->isVoid())
405410
return false;
406411

412+
413+
// Collect extra "SerializationRequirement: SomeProtocol" requirements
414+
if (serializationRequirement && !serializationRequirement->hasError()) {
415+
auto srl = serializationRequirement->getExistentialLayout();
416+
for (auto s: srl.getProtocols()) {
417+
serializationRequirements.insert(s);
418+
}
419+
}
420+
407421
auto isCodableRequirement =
408422
checkDistributedSerializationRequirementIsExactlyCodable(
409423
C, serializationRequirement);
410424

411-
if (serializationRequirement && !serializationRequirement->hasError()) {
412-
auto srl = serializationRequirement->getExistentialLayout();
413-
for (auto serializationReq: srl.getProtocols()) {
414-
auto conformance =
415-
TypeChecker::conformsToProtocol(resultType, serializationReq, module);
416-
if (conformance.isInvalid()) {
417-
if (diagnose) {
418-
llvm::StringRef conformanceToSuggest = isCodableRequirement ?
419-
"Codable" : // Codable is a typealias, easier to diagnose like that
420-
serializationReq->getNameStr();
421-
422-
auto diag = valueDecl->diagnose(
423-
diag::distributed_actor_target_result_not_codable,
424-
resultType,
425-
valueDecl,
426-
conformanceToSuggest
427-
);
428-
429-
if (isCodableRequirement) {
430-
if (auto resultNominalType = resultType->getAnyNominal()) {
431-
addCodableFixIt(resultNominalType, diag);
432-
}
425+
for (auto serializationReq: serializationRequirements) {
426+
auto conformance =
427+
TypeChecker::conformsToProtocol(resultType, serializationReq, module);
428+
if (conformance.isInvalid()) {
429+
if (diagnose) {
430+
llvm::StringRef conformanceToSuggest = isCodableRequirement ?
431+
"Codable" : // Codable is a typealias, easier to diagnose like that
432+
serializationReq->getNameStr();
433+
434+
auto diag = valueDecl->diagnose(
435+
diag::distributed_actor_target_result_not_codable,
436+
resultType,
437+
valueDecl,
438+
conformanceToSuggest
439+
);
440+
441+
if (isCodableRequirement) {
442+
if (auto resultNominalType = resultType->getAnyNominal()) {
443+
addCodableFixIt(resultNominalType, diag);
433444
}
434-
} // end if: diagnose
445+
}
446+
} // end if: diagnose
435447

436-
return true;
437-
}
448+
return true;
438449
}
439450
}
440451

@@ -502,16 +513,16 @@ bool CheckDistributedFunctionRequest::evaluate(
502513
}
503514

504515
auto &C = func->getASTContext();
505-
auto DC = func->getDeclContext();
506516
auto module = func->getParentModule();
507517

508518
/// If no distributed module is available, then no reason to even try checks.
509519
if (!C.getLoadedModule(C.Id_Distributed))
510520
return true;
511521

512-
Type serializationReqType = getConcreteReplacementForMemberSerializationRequirement(func);
513-
for (auto param: *func->getParameters()) {
522+
llvm::SmallPtrSet<ProtocolDecl *, 2> serializationRequirements;
523+
Type serializationReqType = getSerializationRequirementTypesForMember(func, serializationRequirements);
514524

525+
for (auto param: *func->getParameters()) {
515526
// --- Check the parameter conforming to serialization requirements
516527
if (serializationReqType && !serializationReqType->hasError()) {
517528
// If the requirement is exactly `Codable` we diagnose it ia bit nicer.
@@ -574,12 +585,11 @@ bool CheckDistributedFunctionRequest::evaluate(
574585
}
575586
}
576587

577-
if (serializationReqType && !serializationReqType->hasError()) {
578-
// --- Result type must be either void or a codable type
579-
if (checkDistributedTargetResultType(module, func, serializationReqType,
580-
/*diagnose=*/true)) {
581-
return true;
582-
}
588+
// --- Result type must be either void or a serialization requirement conforming type
589+
if (checkDistributedTargetResultType(
590+
module, func, serializationReqType, serializationRequirements,
591+
/*diagnose=*/true)) {
592+
return true;
583593
}
584594

585595
return false;
@@ -627,15 +637,16 @@ bool swift::checkDistributedActorProperty(VarDecl *var, bool diagnose) {
627637
DC->getSelfNominalTypeDecl()->getDistributedActorSystemProperty();
628638
auto systemDecl = systemVar->getInterfaceType()->getAnyNominal();
629639

630-
// auto serializationRequirements =
631-
// getDistributedSerializationRequirementProtocols(
632-
// systemDecl,
633-
// C.getProtocol(KnownProtocolKind::DistributedActorSystem));
640+
auto serializationRequirements =
641+
getDistributedSerializationRequirementProtocols(
642+
systemDecl,
643+
C.getProtocol(KnownProtocolKind::DistributedActorSystem));
644+
634645
auto serializationRequirement =
635-
getConcreteReplacementForMemberSerializationRequirement(systemVar);
646+
getSerializationRequirementTypesForMember(systemVar, serializationRequirements);
636647

637648
auto module = var->getModuleContext();
638-
if (checkDistributedTargetResultType(module, var, serializationRequirement, diagnose)) {
649+
if (checkDistributedTargetResultType(module, var, serializationRequirement, serializationRequirements, diagnose)) {
639650
return true;
640651
}
641652

test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ extension NoSerializationRequirementYet
8282

8383
extension NoSerializationRequirementYet
8484
where SerializationRequirement: Codable {
85-
// expected-error@+1{{result type 'NotCodable' of distributed instance method 'test4' does not conform to serialization requirement 'Codable'}}
85+
// expected-error@+1{{result type 'NotCodable' of distributed instance method 'test4' does not conform to serialization requirement 'Decodable'}}
8686
distributed func test4() -> NotCodable {
8787
.init()
8888
}

0 commit comments

Comments
 (0)