Skip to content

Commit 2d92d4c

Browse files
committed
[Distributed] distributed func checks now work in protocols & dont crash
1 parent 0dae896 commit 2d92d4c

19 files changed

+417
-37
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4791,6 +4791,9 @@ ERROR(distributed_actor_func_static,none,
47914791
ERROR(distributed_actor_func_not_in_distributed_actor,none,
47924792
"'distributed' method can only be declared within 'distributed actor'",
47934793
())
4794+
ERROR(distributed_method_requirement_must_be_async_throws,none, // FIXME(distributed): this is an implementation limitation we should lift
4795+
"'distributed' protocol requirement %0 must currently be declared explicitly 'async throws'",
4796+
(DeclName))
47944797
ERROR(distributed_actor_user_defined_special_property,none,
47954798
"property %0 cannot be defined explicitly, as it conflicts with "
47964799
"distributed actor synthesized stored property",

include/swift/AST/DistributedDecl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ Type getDistributedActorSystemType(NominalTypeDecl *actor);
4747
/// Determine the `ID` type for the given actor.
4848
Type getDistributedActorIDType(NominalTypeDecl *actor);
4949

50+
/// Similar to `getDistributedSerializationRequirementType`, however, from the
51+
/// perspective of a concrete function. This way we're able to get the
52+
/// serialization requirement for specific members, also in protocols.
53+
Type getConcreteReplacementForMemberSerializationRequirement(ValueDecl *member);
54+
5055
/// Get specific 'SerializationRequirement' as defined in 'nominal'
5156
/// type, which must conform to the passed 'protocol' which is expected
5257
/// to require the 'SerializationRequirement'.

include/swift/SIL/SILDeclRef.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,8 +414,7 @@ struct SILDeclRef {
414414
defaultArgIndex,
415415
pointer.get<AutoDiffDerivativeFunctionIdentifier *>());
416416
}
417-
/// Returns the distributed entry point corresponding to the same
418-
/// decl.
417+
/// Returns the distributed entry point corresponding to the same decl.
419418
SILDeclRef asDistributed(bool distributed = true) const {
420419
return SILDeclRef(loc.getOpaqueValue(), kind,
421420
/*foreign=*/false,

lib/AST/DistributedDecl.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,37 @@ Type swift::getConcreteReplacementForProtocolActorSystemType(ValueDecl *member)
9696
llvm_unreachable("Unable to fetch ActorSystem type!");
9797
}
9898

99+
Type swift::getConcreteReplacementForMemberSerializationRequirement(
100+
ValueDecl *member) {
101+
auto &C = member->getASTContext();
102+
auto *DC = member->getDeclContext();
103+
auto DA = C.getDistributedActorDecl();
104+
105+
// === When declared inside an actor, we can get the type directly
106+
if (auto classDecl = DC->getSelfClassDecl()) {
107+
return getDistributedSerializationRequirementType(classDecl, C.getDistributedActorDecl());
108+
}
109+
110+
/// === Maybe the value is declared in a protocol?
111+
if (auto protocol = DC->getSelfProtocolDecl()) {
112+
GenericSignature signature;
113+
if (auto *genericContext = member->getAsGenericContext()) {
114+
signature = genericContext->getGenericSignature();
115+
} else {
116+
signature = DC->getGenericSignatureOfContext();
117+
}
118+
119+
auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement)
120+
->getDeclaredInterfaceType();
121+
122+
// Note that this may be null, e.g. if we're a distributed func inside
123+
// a protocol that did not declare a specific actor system requirement.
124+
return signature->getConcreteType(SerReqAssocType);
125+
}
126+
127+
llvm_unreachable("Unable to fetch ActorSystem type!");
128+
}
129+
99130
Type swift::getDistributedActorSystemType(NominalTypeDecl *actor) {
100131
assert(!dyn_cast<ProtocolDecl>(actor) &&
101132
"Use getConcreteReplacementForProtocolActorSystemType instead to get"

lib/SILGen/SILGenPoly.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4439,15 +4439,22 @@ void SILGenFunction::emitProtocolWitness(
44394439
SmallVector<ManagedValue, 8> origParams;
44404440
collectThunkParams(loc, origParams);
44414441

4442-
// If the witness is isolated to a distributed actor, but the requirement is
4443-
// not, go through the distributed thunk.
44444442
if (witness.hasDecl() &&
4445-
getActorIsolation(witness.getDecl()).isDistributedActor() &&
4446-
requirement.hasDecl() &&
4447-
!getActorIsolation(requirement.getDecl()).isDistributedActor()) {
4448-
witness = SILDeclRef(
4449-
cast<AbstractFunctionDecl>(witness.getDecl())->getDistributedThunk())
4450-
.asDistributed();
4443+
getActorIsolation(witness.getDecl()).isDistributedActor()) {
4444+
// We witness protocol requirements using the distributed thunk, when:
4445+
// - the witness is isolated to a distributed actor, but the requirement is not
4446+
// - the requirement is a distributed func, and therefore can only be witnessed
4447+
// by a distributed func; we handle this by witnessing the requirement with the thunk
4448+
// FIXME(distributed): this limits us to only allow distributed explicitly throwing async requirements... we need to fix this somehow.
4449+
if (requirement.hasDecl()) {
4450+
if ((!getActorIsolation(requirement.getDecl()).isDistributedActor()) ||
4451+
(isa<FuncDecl>(requirement.getDecl()) &&
4452+
witness.getFuncDecl()->isDistributed())) {
4453+
auto thunk = cast<AbstractFunctionDecl>(witness.getDecl())
4454+
->getDistributedThunk();
4455+
witness = SILDeclRef(thunk).asDistributed();
4456+
}
4457+
}
44514458
} else if (enterIsolation) {
44524459
// If we are supposed to enter the actor, do so now by hopping to the
44534460
// actor.

lib/Sema/CodeSynthesisDistributedActor.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -648,8 +648,13 @@ static FuncDecl *createDistributedThunkFunction(FuncDecl *func) {
648648
auto &C = func->getASTContext();
649649
auto DC = func->getDeclContext();
650650

651-
auto systemTy = getConcreteReplacementForProtocolActorSystemType(func);
652-
assert(systemTy &&
651+
// NOTE: So we don't need a thunk in the protocol, we should call the underlying
652+
// thing instead, which MUST have a thunk, since it must be a distributed func as well...
653+
if (dyn_cast<ProtocolDecl>(DC)) {
654+
return nullptr;
655+
}
656+
657+
assert(getConcreteReplacementForProtocolActorSystemType(func) &&
653658
"Thunk synthesis must have concrete actor system type available");
654659

655660
DeclName thunkName = func->getName();
@@ -788,7 +793,6 @@ FuncDecl *GetDistributedThunkRequest::evaluate(
788793
return nullptr;
789794

790795
auto &C = distributedTarget->getASTContext();
791-
auto DC = distributedTarget->getDeclContext();
792796

793797
if (!getConcreteReplacementForProtocolActorSystemType(distributedTarget)) {
794798
// Don't synthesize thunks, unless there is a *concrete* ActorSystem.
@@ -812,9 +816,6 @@ FuncDecl *GetDistributedThunkRequest::evaluate(
812816
if (!C.getLoadedModule(C.Id_Distributed))
813817
return nullptr;
814818

815-
auto nominal = DC->getSelfNominalTypeDecl(); // NOTE: Always from DC
816-
assert(nominal);
817-
818819
// --- Prepare the "distributed thunk" which does the "maybe remote" dance:
819820
return createDistributedThunkFunction(func);
820821
}

lib/Sema/TypeCheckAttr.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5900,16 +5900,32 @@ void AttributeChecker::visitDistributedActorAttr(DistributedActorAttr *attr) {
59005900

59015901
// distributed func must be declared inside an distributed actor
59025902
auto selfTy = dc->getSelfTypeInContext();
5903+
5904+
auto *protoDecl = dc->getSelfProtocolDecl();
59035905
if (!selfTy->isDistributedActor()) {
59045906
auto diagnostic = diagnoseAndRemoveAttr(
59055907
attr, diag::distributed_actor_func_not_in_distributed_actor);
59065908

5907-
if (auto *protoDecl = dc->getSelfProtocolDecl()) {
5909+
if (protoDecl) {
59085910
diagnoseDistributedFunctionInNonDistributedActorProtocol(protoDecl,
59095911
diagnostic);
59105912
}
59115913
return;
59125914
}
5915+
5916+
if (isa<ProtocolDecl>(dc)) {
5917+
if (!funcDecl->hasAsync() || !funcDecl->hasThrows()) {
5918+
auto diag = funcDecl->diagnose(diag::distributed_method_requirement_must_be_async_throws,
5919+
funcDecl->getName());
5920+
if (!funcDecl->hasAsync()) {
5921+
diag.fixItInsertAfter(funcDecl->getThrowsLoc(), " async");
5922+
}
5923+
if (!funcDecl->hasThrows()) {
5924+
diag.fixItInsertAfter(funcDecl->getThrowsLoc(), " throws");
5925+
}
5926+
return;
5927+
}
5928+
}
59135929
}
59145930
}
59155931

lib/Sema/TypeCheckDeclPrimary.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "TypeCheckAvailability.h"
2424
#include "TypeCheckConcurrency.h"
2525
#include "TypeCheckDecl.h"
26+
#include "TypeCheckDistributed.h"
2627
#include "TypeCheckObjC.h"
2728
#include "TypeCheckType.h"
2829
#include "TypeChecker.h"
@@ -33,6 +34,7 @@
3334
#include "swift/AST/AccessScope.h"
3435
#include "swift/AST/ExistentialLayout.h"
3536
#include "swift/AST/Expr.h"
37+
#include "swift/AST/DistributedDecl.h"
3638
#include "swift/AST/ForeignErrorConvention.h"
3739
#include "swift/AST/GenericEnvironment.h"
3840
#include "swift/AST/Initializer.h"
@@ -2875,6 +2877,7 @@ class DeclChecker : public DeclVisitor<DeclChecker> {
28752877
}
28762878

28772879
TypeChecker::checkDeclAttributes(FD);
2880+
TypeChecker::checkDistributedFunc(FD);
28782881

28792882
if (!checkOverrides(FD)) {
28802883
// If a method has an 'override' keyword but does not

lib/Sema/TypeCheckDistributed.cpp

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -503,9 +503,25 @@ bool CheckDistributedFunctionRequest::evaluate(
503503
serializationRequirements = getDistributedSerializationRequirementProtocols(
504504
getDistributedActorSystemType(actor)->getAnyNominal(),
505505
C.getProtocol(KnownProtocolKind::DistributedActorSystem));
506+
} else if (isa<ProtocolDecl>(DC)) {
507+
if (auto seqReqTy =
508+
getConcreteReplacementForMemberSerializationRequirement(func)) {
509+
auto seqReqTyDes = seqReqTy->castTo<ExistentialType>()->getConstraintType()->getDesugaredType();
510+
for (auto req : flattenDistributedSerializationTypeToRequiredProtocols(seqReqTyDes)) {
511+
serializationRequirements.insert(req);
512+
}
513+
}
514+
515+
// The distributed actor constrained protocol has no serialization requirements
516+
// or actor system defined, so these will only be enforced, by implementations
517+
// of DAs conforming to it, skip checks here.
518+
if (serializationRequirements.empty()) {
519+
return false;
520+
}
506521
} else {
507-
llvm_unreachable("Cannot handle types other than extensions and actor "
508-
"declarations in distributed function checking.");
522+
llvm_unreachable("Distributed function detected in type other than extension, "
523+
"distributed actor, or protocol! This should not be possible "
524+
", please file a bug.");
509525
}
510526

511527
// If the requirement is exactly `Codable` we diagnose it ia bit nicer.
@@ -653,12 +669,23 @@ void TypeChecker::checkDistributedActor(SourceFile *SF, NominalTypeDecl *nominal
653669
// If applicable, this will create the default 'init(transport:)' initializer
654670
(void)nominal->getDefaultInitializer();
655671

672+
656673
for (auto member : nominal->getMembers()) {
657674
// --- Ensure all thunks
658675
if (auto func = dyn_cast<AbstractFunctionDecl>(member)) {
659676
if (!func->isDistributed())
660677
continue;
661678

679+
if (!isa<ProtocolDecl>(nominal)) {
680+
auto systemTy = getConcreteReplacementForProtocolActorSystemType(func);
681+
if (!systemTy || systemTy->hasError()) {
682+
nominal->diagnose(
683+
diag::distributed_actor_conformance_missing_system_type,
684+
nominal->getName());
685+
return;
686+
}
687+
}
688+
662689
if (auto thunk = func->getDistributedThunk()) {
663690
SF->DelayedFunctions.push_back(thunk);
664691
}
@@ -676,6 +703,13 @@ void TypeChecker::checkDistributedActor(SourceFile *SF, NominalTypeDecl *nominal
676703
(void)nominal->getDistributedActorIDProperty();
677704
}
678705

706+
void TypeChecker::checkDistributedFunc(FuncDecl *func) {
707+
if (!func->isDistributed())
708+
return;
709+
710+
swift::checkDistributedFunction(func);
711+
}
712+
679713
llvm::SmallPtrSet<ProtocolDecl *, 2>
680714
swift::getDistributedSerializationRequirementProtocols(
681715
NominalTypeDecl *nominal, ProtocolDecl *protocol) {

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1401,9 +1401,10 @@ bool WitnessChecker::findBestWitness(
14011401
attempt = static_cast<Attempt>(attempt + 1)) {
14021402
SmallVector<ValueDecl *, 4> witnesses;
14031403
switch (attempt) {
1404-
case Regular:
1404+
case Regular: {
14051405
witnesses = lookupValueWitnesses(requirement, ignoringNames);
14061406
break;
1407+
}
14071408
case OperatorsFromOverlay: {
14081409
// If we have a Clang declaration, the matching operator might be in the
14091410
// overlay for that module.

lib/Sema/TypeChecker.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,9 @@ diagnosePotentialOpaqueTypeUnavailability(SourceRange ReferenceRange,
11011101
/// Type check a 'distributed actor' declaration.
11021102
void checkDistributedActor(SourceFile *SF, NominalTypeDecl *decl);
11031103

1104+
/// Type check a single 'distributed func' declaration.
1105+
void checkDistributedFunc(FuncDecl *func);
1106+
11041107
void checkConcurrencyAvailability(SourceRange ReferenceRange,
11051108
const DeclContext *ReferenceDC);
11061109

0 commit comments

Comments
 (0)