Skip to content

Commit 8d3e7d9

Browse files
authored
[Distributed] ResultHandler.onReturn must be ad-hoc because SerializationRequirement (#41916)
* [Distributed] Invoke handler.onReturn ad hoc via ast synthesized func * reformat and cleanup * remove unused var
1 parent e508ce3 commit 8d3e7d9

37 files changed

+1093
-500
lines changed

all.txt

Lines changed: 242 additions & 0 deletions
Large diffs are not rendered by default.

include/swift/AST/Decl.h

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3500,12 +3500,17 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext {
35003500
/// Find, or potentially synthesize, the implicit 'id' property of this actor.
35013501
VarDecl *getDistributedActorIDProperty() const;
35023502

3503-
/// Find the 'RemoteCallTarget.init(_:)' initializer function
3503+
/// Find the 'RemoteCallTarget.init(_:)' initializer function.
35043504
ConstructorDecl* getDistributedRemoteCallTargetInitFunction() const;
35053505

3506-
/// Find the 'RemoteCallArgument(label:name:value:)' initializer function
3506+
/// Find the 'RemoteCallArgument(label:name:value:)' initializer function.
35073507
ConstructorDecl* getDistributedRemoteCallArgumentInitFunction() const;
35083508

3509+
/// Find the
3510+
/// 'DistributedActorSystem.invokeHandlerOnReturn(handler:value:metatype:)
3511+
/// function.
3512+
FuncDecl *getDistributedActorSystemInvokeHandlerOnReturnFunction() const;
3513+
35093514
/// Collect the set of protocols to which this type should implicitly
35103515
/// conform, such as AnyObject (for classes).
35113516
void getImplicitProtocols(SmallVectorImpl<ProtocolDecl *> &protocols);
@@ -4301,6 +4306,7 @@ enum class KnownDerivableProtocolKind : uint8_t {
43014306
Differentiable,
43024307
Actor,
43034308
DistributedActor,
4309+
DistributedActorSystem,
43044310
};
43054311

43064312
/// ProtocolDecl - A declaration of a protocol, for example:
@@ -5592,6 +5598,19 @@ class ParamDecl : public VarDecl {
55925598
/// Create a an identical copy of this ParamDecl.
55935599
static ParamDecl *clone(const ASTContext &Ctx, ParamDecl *PD);
55945600

5601+
static ParamDecl *
5602+
createImplicit(ASTContext &Context, SourceLoc specifierLoc,
5603+
SourceLoc argumentNameLoc, Identifier argumentName,
5604+
SourceLoc parameterNameLoc, Identifier parameterName,
5605+
Type interfaceType, DeclContext *Parent,
5606+
ParamSpecifier specifier = ParamSpecifier::Default);
5607+
5608+
static ParamDecl *
5609+
createImplicit(ASTContext &Context, Identifier argumentName,
5610+
Identifier parameterName, Type interfaceType,
5611+
DeclContext *Parent,
5612+
ParamSpecifier specifier = ParamSpecifier::Default);
5613+
55955614
/// Retrieve the argument (API) name for this function parameter.
55965615
Identifier getArgumentName() const {
55975616
return ArgumentNameAndFlags.getPointer();
@@ -6447,6 +6466,10 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
64476466
return getBodyKind() == BodyKind::TypeChecked;
64486467
}
64496468

6469+
bool isBodySILSynthesize() const {
6470+
return getBodyKind() == BodyKind::SILSynthesize;
6471+
}
6472+
64506473
bool isBodySkipped() const {
64516474
return getBodyKind() == BodyKind::Skipped;
64526475
}
@@ -6460,8 +6483,8 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
64606483
/// initialization factory. Such functions do not have a body that is
64616484
/// representable in the AST, so it must be synthesized during SILGen.
64626485
bool isDistributedActorFactory() const {
6463-
return getBodyKind() == BodyKind::SILSynthesize
6464-
&& getSILSynthesizeKind() == SILSynthesizeKind::DistributedActorFactory;
6486+
return getBodyKind() == BodyKind::SILSynthesize &&
6487+
getSILSynthesizeKind() == SILSynthesizeKind::DistributedActorFactory;
64656488
}
64666489

64676490
/// Determines whether this function is a 'remoteCall' function,

include/swift/AST/DistributedDecl.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,12 @@ Type getDistributedSerializationRequirementType(
5757
/// system.
5858
Type getDistributedActorSystemInvocationEncoderType(NominalTypeDecl *system);
5959

60+
/// Get the specific 'ResultHandler' type of a specific distributed actor
61+
/// system.
62+
Type getDistributedActorSystemResultHandlerType(NominalTypeDecl *system);
63+
6064
/// Get the 'ActorID' type of a specific distributed actor system.
61-
Type getDistributedActorSystemActorIDRequirementType(NominalTypeDecl *system);
65+
Type getDistributedActorSystemActorIDType(NominalTypeDecl *system);
6266

6367
/// Get the specific protocols that the `SerializationRequirement` specifies,
6468
/// and all parameters / return types of distributed targets must conform to.

include/swift/AST/KnownIdentifiers.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,11 @@ IDENTIFIER(assignID)
275275
IDENTIFIER(decodeNext)
276276
IDENTIFIER(doneRecording)
277277
IDENTIFIER(id)
278+
IDENTIFIER(metatype)
279+
IDENTIFIER(handler)
278280
IDENTIFIER(invocation)
279281
IDENTIFIER(invocationDecoder)
282+
IDENTIFIER(invokeHandlerOnReturn)
280283
IDENTIFIER(makeInvocationEncoder)
281284
IDENTIFIER(on)
282285
IDENTIFIER(onReturn)
@@ -286,9 +289,11 @@ IDENTIFIER(recordGenericSubstitution)
286289
IDENTIFIER(recordReturnType)
287290
IDENTIFIER(remoteCall)
288291
IDENTIFIER(remoteCallVoid)
292+
IDENTIFIER(resultBuffer)
289293
IDENTIFIER(resignID)
290294
IDENTIFIER(resolve)
291295
IDENTIFIER(returning)
296+
IDENTIFIER(ResultHandler)
292297
IDENTIFIER(system)
293298
IDENTIFIER(target)
294299
IDENTIFIER(throwing)

include/swift/AST/TypeCheckRequests.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,6 +1269,25 @@ class GetDistributedThunkRequest :
12691269
bool isCached() const { return true; }
12701270
};
12711271

1272+
/// Synthesize and return the 'invokeHandlerOnReturn' method for a concrete
1273+
/// actor system.
1274+
class GetDistributedActorSystemInvokeHandlerOnReturnRequest
1275+
: public SimpleRequest<
1276+
GetDistributedActorSystemInvokeHandlerOnReturnRequest,
1277+
FuncDecl *(NominalTypeDecl *), RequestFlags::Cached> {
1278+
public:
1279+
using SimpleRequest::SimpleRequest;
1280+
1281+
private:
1282+
friend SimpleRequest;
1283+
1284+
FuncDecl *evaluate(Evaluator &evaluator, NominalTypeDecl *system) const;
1285+
1286+
public:
1287+
// Caching
1288+
bool isCached() const { return true; }
1289+
};
1290+
12721291
/// Obtain the 'id' property of a 'distributed actor'.
12731292
class GetDistributedActorIDPropertyRequest :
12741293
public SimpleRequest<GetDistributedActorIDPropertyRequest,

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ SWIFT_REQUEST(TypeChecker, GetDistributedRemoteCallTargetInitFunctionRequest,
148148
SWIFT_REQUEST(TypeChecker, GetDistributedRemoteCallArgumentInitFunctionRequest,
149149
ConstructorDecl *(NominalTypeDecl *),
150150
Cached, NoLocationInfo)
151+
SWIFT_REQUEST(TypeChecker, GetDistributedActorSystemInvokeHandlerOnReturnRequest,
152+
FuncDecl *(NominalTypeDecl *),
153+
Cached, NoLocationInfo)
151154
SWIFT_REQUEST(TypeChecker, GetDistributedActorInvocationDecoderRequest,
152155
NominalTypeDecl *(NominalTypeDecl *),
153156
Cached, NoLocationInfo)

lib/AST/ASTDumper.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2625,11 +2625,11 @@ class PrintExpr : public ExprVisitor<PrintExpr> {
26252625
printCommon(E, name) << ' ';
26262626
if (auto checkedCast = dyn_cast<CheckedCastExpr>(E))
26272627
OS << getCheckedCastKindName(checkedCast->getCastKind()) << ' ';
2628-
OS << "writtenType='";
2629-
if (GetTypeOfTypeRepr)
2630-
GetTypeOfTypeRepr(E->getCastTypeRepr()).print(OS);
2631-
else
2632-
E->getCastType().print(OS);
2628+
OS << "writtenType='";
2629+
if (GetTypeOfTypeRepr)
2630+
GetTypeOfTypeRepr(E->getCastTypeRepr()).print(OS);
2631+
else
2632+
E->getCastType().print(OS);
26332633
OS << "'\n";
26342634
printRec(E->getSubExpr());
26352635
PrintWithColorRAII(OS, ParenthesisColor) << ')';

lib/AST/Decl.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5626,6 +5626,8 @@ Optional<KnownDerivableProtocolKind>
56265626
return KnownDerivableProtocolKind::Actor;
56275627
case KnownProtocolKind::DistributedActor:
56285628
return KnownDerivableProtocolKind::DistributedActor;
5629+
case KnownProtocolKind::DistributedActorSystem:
5630+
return KnownDerivableProtocolKind::DistributedActorSystem;
56295631
default: return None;
56305632
}
56315633
}
@@ -6830,6 +6832,32 @@ ParamDecl *ParamDecl::clone(const ASTContext &Ctx, ParamDecl *PD) {
68306832
return Clone;
68316833
}
68326834

6835+
ParamDecl *
6836+
ParamDecl::createImplicit(ASTContext &Context, SourceLoc specifierLoc,
6837+
SourceLoc argumentNameLoc, Identifier argumentName,
6838+
SourceLoc parameterNameLoc, Identifier parameterName,
6839+
Type interfaceType, DeclContext *Parent,
6840+
ParamSpecifier specifier) {
6841+
auto decl =
6842+
new (Context) ParamDecl(specifierLoc, argumentNameLoc, argumentName,
6843+
parameterNameLoc, parameterName, Parent);
6844+
decl->setImplicit();
6845+
// implicit ParamDecls must have a specifier set
6846+
decl->setSpecifier(specifier);
6847+
decl->setInterfaceType(interfaceType);
6848+
return decl;
6849+
}
6850+
6851+
ParamDecl *ParamDecl::createImplicit(ASTContext &Context,
6852+
Identifier argumentName,
6853+
Identifier parameterName,
6854+
Type interfaceType, DeclContext *Parent,
6855+
ParamSpecifier specifier) {
6856+
return ParamDecl::createImplicit(Context, SourceLoc(), SourceLoc(),
6857+
argumentName, SourceLoc(), parameterName,
6858+
interfaceType, Parent, specifier);
6859+
}
6860+
68336861
/// Retrieve the type of 'self' for the given context.
68346862
Type DeclContext::getSelfTypeInContext() const {
68356863
assert(isTypeContext());

lib/AST/DistributedDecl.cpp

Lines changed: 79 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -116,33 +116,55 @@ Type swift::getDistributedActorIDType(NominalTypeDecl *actor) {
116116
return C.getAssociatedTypeOfDistributedSystemOfActor(actor, C.Id_ActorID);
117117
}
118118

119-
Type swift::getDistributedActorSystemActorIDRequirementType(NominalTypeDecl *system) {
119+
Type swift::getDistributedActorSystemActorIDType(NominalTypeDecl *system) {
120120
assert(!system->isDistributedActor());
121121
auto &ctx = system->getASTContext();
122122

123-
auto protocol = ctx.getDistributedActorSystemDecl();
124-
if (!protocol)
123+
auto DAS = ctx.getDistributedActorSystemDecl();
124+
if (!DAS)
125125
return Type();
126126

127127
// Dig out the serialization requirement type.
128128
auto module = system->getParentModule();
129129
Type selfType = system->getSelfInterfaceType();
130-
auto conformance = module->lookupConformance(selfType, protocol);
130+
auto conformance = module->lookupConformance(selfType, DAS);
131131
return conformance.getTypeWitnessByName(selfType, ctx.Id_ActorID);
132132
}
133133

134+
Type swift::getDistributedActorSystemResultHandlerType(
135+
NominalTypeDecl *system) {
136+
assert(!system->isDistributedActor());
137+
auto &ctx = system->getASTContext();
138+
139+
auto DAS = ctx.getDistributedActorSystemDecl();
140+
if (!DAS)
141+
return Type();
142+
143+
// Dig out the serialization requirement type.
144+
auto module = system->getParentModule();
145+
Type selfType = system->getSelfInterfaceType();
146+
auto conformance = module->lookupConformance(selfType, DAS);
147+
auto witness =
148+
conformance.getTypeWitnessByName(selfType, ctx.Id_ResultHandler);
149+
if (auto alias = dyn_cast<TypeAliasType>(witness.getPointer())) {
150+
return alias->getDecl()->getUnderlyingType();
151+
} else {
152+
return witness;
153+
}
154+
}
155+
134156
Type swift::getDistributedActorSystemInvocationEncoderType(NominalTypeDecl *system) {
135157
assert(!system->isDistributedActor());
136158
auto &ctx = system->getASTContext();
137159

138-
auto protocol = ctx.getDistributedActorSystemDecl();
139-
if (!protocol)
160+
auto DAS = ctx.getDistributedActorSystemDecl();
161+
if (!DAS)
140162
return Type();
141163

142164
// Dig out the serialization requirement type.
143165
auto module = system->getParentModule();
144166
Type selfType = system->getSelfInterfaceType();
145-
auto conformance = module->lookupConformance(selfType, protocol);
167+
auto conformance = module->lookupConformance(selfType, DAS);
146168
return conformance.getTypeWitnessByName(selfType, ctx.Id_InvocationEncoder);
147169
}
148170

@@ -494,8 +516,7 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
494516
if (actorIdReq.getKind() != RequirementKind::SameType) {
495517
return false;
496518
}
497-
auto expectedActorIdTy =
498-
getDistributedActorSystemActorIDRequirementType(systemNominal);
519+
auto expectedActorIdTy = getDistributedActorSystemActorIDType(systemNominal);
499520
if (!actorIdReq.getSecondType()->isEqual(expectedActorIdTy)) {
500521
return false;
501522
}
@@ -1036,6 +1057,11 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const
10361057
auto &C = getASTContext();
10371058
auto module = getParentModule();
10381059

1060+
auto func = dyn_cast<FuncDecl>(this);
1061+
if (!func) {
1062+
return false;
1063+
}
1064+
10391065
// === Check base name
10401066
if (getBaseIdentifier() != C.Id_onReturn) {
10411067
return false;
@@ -1069,7 +1095,20 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const
10691095
return false;
10701096
}
10711097

1072-
// TODO(distributed): check generics here
1098+
// --- Check number of generic parameters
1099+
auto genericParams = getGenericParams();
1100+
unsigned int expectedGenericParamNum = 1;
1101+
1102+
if (genericParams->size() != expectedGenericParamNum) {
1103+
return false;
1104+
}
1105+
1106+
// === Get the SerializationRequirement
1107+
SmallPtrSet<ProtocolDecl *, 2> requirementProtos;
1108+
if (!getDistributedSerializationRequirements(decoderNominal, decoderProto,
1109+
requirementProtos)) {
1110+
return false;
1111+
}
10731112

10741113
// === Check all parameters
10751114
auto params = getParameters();
@@ -1083,11 +1122,27 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const
10831122
return false;
10841123
}
10851124

1086-
auto func = dyn_cast<FuncDecl>(this);
1087-
if (!func) {
1125+
// === Check generic parameters in detail
1126+
// --- Check: Argument: SerializationRequirement
1127+
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
1128+
auto argumentType = func->mapTypeIntoContext(valueParam->getInterfaceType())
1129+
->getMetatypeInstanceType()
1130+
->getDesugaredType();
1131+
auto resultParamType = func->mapTypeIntoContext(
1132+
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
1133+
// The result of the function must be the `Res` generic argument.
1134+
if (!argumentType->isEqual(resultParamType)) {
10881135
return false;
10891136
}
10901137

1138+
for (auto requirementProto : requirementProtos) {
1139+
auto conformance =
1140+
module->lookupConformance(argumentType, requirementProto);
1141+
if (conformance.isInvalid()) {
1142+
return false;
1143+
}
1144+
}
1145+
10911146
if (!func->getResultInterfaceType()->isVoid()) {
10921147
return false;
10931148
}
@@ -1160,7 +1215,18 @@ NominalTypeDecl::getDistributedRemoteCallArgumentInitFunction() const {
11601215
auto mutableThis = const_cast<NominalTypeDecl *>(this);
11611216
return evaluateOrDefault(
11621217
getASTContext().evaluator,
1163-
GetDistributedRemoteCallArgumentInitFunctionRequest(mutableThis), nullptr);
1218+
GetDistributedRemoteCallArgumentInitFunctionRequest(mutableThis),
1219+
nullptr);
1220+
}
1221+
1222+
FuncDecl *
1223+
NominalTypeDecl::getDistributedActorSystemInvokeHandlerOnReturnFunction()
1224+
const {
1225+
auto mutableThis = const_cast<NominalTypeDecl *>(this);
1226+
return evaluateOrDefault(
1227+
getASTContext().evaluator,
1228+
GetDistributedActorSystemInvokeHandlerOnReturnRequest(mutableThis),
1229+
nullptr);
11641230
}
11651231

11661232
AbstractFunctionDecl *ASTContext::getRemoteCallOnDistributedActorSystem(

lib/SILGen/SILGenDistributed.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,9 +252,10 @@ void SILGenFunction::emitDistributedActorReady(
252252
emitActorReadyCall(B, loc, borrowedSelf.getValue(), transport);
253253
}
254254

255+
// ==== ------------------------------------------------------------------------
255256
// MARK: remote instance initialization
256257

257-
/// Synthesize the distributed actor's identity (`id`) initialization:
258+
/// emit a call to the distributed actor system's resolve function:
258259
///
259260
/// \verbatim
260261
/// system.resolve(id:as:)
@@ -408,6 +409,7 @@ void SILGenFunction::emitDistributedActorFactory(FuncDecl *fd) { // TODO(distrib
408409
}
409410
}
410411

412+
// ==== ------------------------------------------------------------------------
411413
// MARK: system.resignID()
412414

413415
void SILGenFunction::emitDistributedActorSystemResignIDCall(

lib/SILGen/SILGenFunction.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2073,7 +2073,8 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
20732073

20742074
/// Given a function representing a distributed actor factory, emits the
20752075
/// corresponding SIL function for it.
2076-
void emitDistributedActorFactory(FuncDecl *fd);
2076+
void emitDistributedActorFactory(
2077+
FuncDecl *fd); // TODO(distributed): this is the "resolve"
20772078

20782079
/// Notify transport that actor has initialized successfully,
20792080
/// and is ready to receive messages.

0 commit comments

Comments
 (0)