Skip to content

Commit 61da992

Browse files
authored
Merge pull request #41799 from ktoso/wip-argumentNameRecording
[Distributed] Add name parameter to recordArgument for better interop
2 parents b258865 + 5a5d1ba commit 61da992

28 files changed

+320
-106
lines changed

include/swift/AST/Decl.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3493,9 +3493,12 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext {
34933493
/// Find, or potentially synthesize, the implicit 'id' property of this actor.
34943494
VarDecl *getDistributedActorIDProperty() const;
34953495

3496-
/// Find the 'RemoteCallTarget.init(_mangledName:)' initializer function
3496+
/// Find the 'RemoteCallTarget.init(_:)' initializer function
34973497
ConstructorDecl* getDistributedRemoteCallTargetInitFunction() const;
34983498

3499+
/// Find the 'RemoteCallArgument(label:name:value:)' initializer function
3500+
ConstructorDecl* getDistributedRemoteCallArgumentInitFunction() const;
3501+
34993502
/// Collect the set of protocols to which this type should implicitly
35003503
/// conform, such as AnyObject (for classes).
35013504
void getImplicitProtocols(SmallVectorImpl<ProtocolDecl *> &protocols);

include/swift/AST/KnownSDKTypes.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ KNOWN_SDK_TYPE_DECL(Distributed, DistributedActorSystem, ProtocolDecl, 0)
4949
KNOWN_SDK_TYPE_DECL(Distributed, DistributedTargetInvocationEncoder, ProtocolDecl, 0)
5050
KNOWN_SDK_TYPE_DECL(Distributed, DistributedTargetInvocationDecoder, ProtocolDecl, 0)
5151
KNOWN_SDK_TYPE_DECL(Distributed, RemoteCallTarget, StructDecl, 0)
52+
KNOWN_SDK_TYPE_DECL(Distributed, RemoteCallArgument, StructDecl, 1)
5253

5354
// String processing
5455
KNOWN_SDK_TYPE_DECL(StringProcessing, Regex, StructDecl, 1)

include/swift/AST/TypeCheckRequests.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1210,7 +1210,7 @@ class GetDistributedActorSystemPropertyRequest :
12101210
bool isCached() const { return true; }
12111211
};
12121212

1213-
/// Obtain the constructor of the RemoteCallTarget type.
1213+
/// Obtain the constructor of the 'RemoteCallTarget' type.
12141214
class GetDistributedRemoteCallTargetInitFunctionRequest :
12151215
public SimpleRequest<GetDistributedRemoteCallTargetInitFunctionRequest,
12161216
ConstructorDecl *(NominalTypeDecl *),
@@ -1229,6 +1229,25 @@ class GetDistributedRemoteCallTargetInitFunctionRequest :
12291229
bool isCached() const { return true; }
12301230
};
12311231

1232+
/// Obtain the constructor of the 'RemoteCallArgument' type.
1233+
class GetDistributedRemoteCallArgumentInitFunctionRequest :
1234+
public SimpleRequest<GetDistributedRemoteCallArgumentInitFunctionRequest,
1235+
ConstructorDecl *(NominalTypeDecl *),
1236+
RequestFlags::Cached> {
1237+
public:
1238+
using SimpleRequest::SimpleRequest;
1239+
1240+
private:
1241+
friend SimpleRequest;
1242+
1243+
ConstructorDecl *evaluate(Evaluator &evaluator,
1244+
NominalTypeDecl *nominal) const;
1245+
1246+
public:
1247+
// Caching
1248+
bool isCached() const { return true; }
1249+
};
1250+
12321251
/// Obtain the 'distributed thunk' for the passed-in function.
12331252
///
12341253
/// The thunk is responsible for invoking 'remoteCall' when invoked on a remote

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ SWIFT_REQUEST(TypeChecker, GetDistributedActorSystemPropertyRequest,
145145
SWIFT_REQUEST(TypeChecker, GetDistributedRemoteCallTargetInitFunctionRequest,
146146
ConstructorDecl *(NominalTypeDecl *),
147147
Cached, NoLocationInfo)
148+
SWIFT_REQUEST(TypeChecker, GetDistributedRemoteCallArgumentInitFunctionRequest,
149+
ConstructorDecl *(NominalTypeDecl *),
150+
Cached, NoLocationInfo)
148151
SWIFT_REQUEST(TypeChecker, GetDistributedActorInvocationDecoderRequest,
149152
NominalTypeDecl *(NominalTypeDecl *),
150153
Cached, NoLocationInfo)

lib/AST/DistributedDecl.cpp

Lines changed: 64 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,11 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
581581
auto &C = getASTContext();
582582
auto module = getParentModule();
583583

584+
auto func = dyn_cast<FuncDecl>(this);
585+
if (!func) {
586+
return false;
587+
}
588+
584589
// === Check base name
585590
if (getBaseIdentifier() != C.Id_recordArgument) {
586591
return false;
@@ -614,6 +619,12 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
614619
return false;
615620
}
616621

622+
// --- must be mutating, if it is defined in a struct
623+
if (isa<StructDecl>(getDeclContext()) &&
624+
!func->isMutating()) {
625+
return false;
626+
}
627+
617628
// --- Check number of generic parameters
618629
auto genericParams = getGenericParams();
619630
unsigned int expectedGenericParamNum = 1;
@@ -639,56 +650,60 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
639650
return false;
640651
}
641652

642-
// --- Check parameter: _ argument
643-
auto argumentParam = params->get(0);
644-
if (!argumentParam->getArgumentName().is("")) {
645-
return false;
646-
}
647-
648-
// === Check generic parameters in detail
649-
// --- Check: Argument: SerializationRequirement
650653
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
651654

652-
auto sig = getGenericSignature();
653-
auto requirements = sig.getRequirements();
655+
// --- Check parameter: _ argument
656+
auto argumentParam = params->get(0);
657+
if (!argumentParam->getArgumentName().empty()) {
658+
return false;
659+
}
654660

655-
if (requirements.size() != expectedRequirementsNum) {
656-
return false;
657-
}
661+
auto argumentTy = argumentParam->getInterfaceType();
662+
auto argumentInContextTy = mapTypeIntoContext(argumentTy);
663+
if (argumentInContextTy->getAnyNominal() == C.getRemoteCallArgumentDecl()) {
664+
auto argGenericParams = argumentInContextTy->getStructOrBoundGenericStruct()
665+
->getGenericParams()->getParams();
666+
if (argGenericParams.size() != 1) {
667+
return false;
668+
}
658669

659-
// --- Check the expected requirements
660-
// --- all the Argument requirements ---
661-
// conforms_to: Argument Decodable
662-
// conforms_to: Argument Encodable
663-
// ...
670+
// the <Value> of the RemoteCallArgument<Value>
671+
auto remoteCallArgValueGenericTy =
672+
mapTypeIntoContext(argGenericParams[0]->getInterfaceType())
673+
->getDesugaredType()
674+
->getMetatypeInstanceType();
675+
// expected (the <Value> from the recordArgument<Value>)
676+
auto expectedGenericParamTy = mapTypeIntoContext(
677+
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
678+
679+
if (!remoteCallArgValueGenericTy->isEqual(expectedGenericParamTy)) {
680+
return false;
681+
}
682+
} else {
683+
return false;
684+
}
664685

665-
auto func = dyn_cast<FuncDecl>(this);
666-
if (!func) {
667-
return false;
668-
}
669686

670-
auto resultType = func->mapTypeIntoContext(argumentParam->getInterfaceType())
671-
->getDesugaredType();
672-
auto resultParamType = func->mapTypeIntoContext(
673-
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
674-
// The result of the function must be the `Res` generic argument.
675-
if (!resultType->isEqual(resultParamType)) {
676-
return false;
677-
}
687+
auto sig = getGenericSignature();
688+
auto requirements = sig.getRequirements();
678689

679-
for (auto requirementProto : requirementProtos) {
680-
auto conformance = module->lookupConformance(resultType, requirementProto);
681-
if (conformance.isInvalid()) {
690+
if (requirements.size() != expectedRequirementsNum) {
682691
return false;
683692
}
684-
}
685693

686-
// === Check result type: Void
687-
if (!func->getResultInterfaceType()->isVoid()) {
688-
return false;
689-
}
694+
// --- Check the expected requirements
695+
// --- all the Argument requirements ---
696+
// e.g.
697+
// conforms_to: Argument Decodable
698+
// conforms_to: Argument Encodable
699+
// ...
690700

691-
return true;
701+
// === Check result type: Void
702+
if (!func->getResultInterfaceType()->isVoid()) {
703+
return false;
704+
}
705+
706+
return true;
692707
}
693708

694709
bool
@@ -879,8 +894,8 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordErrorType() cons
879894
}
880895

881896
// --- Check parameter: _ errorType
882-
auto argumentParam = params->get(0);
883-
if (!argumentParam->getArgumentName().is("")) {
897+
auto errorTypeParam = params->get(0);
898+
if (!errorTypeParam->getArgumentName().is("")) {
884899
return false;
885900
}
886901

@@ -1140,6 +1155,14 @@ NominalTypeDecl::getDistributedRemoteCallTargetInitFunction() const {
11401155
GetDistributedRemoteCallTargetInitFunctionRequest(mutableThis), nullptr);
11411156
}
11421157

1158+
ConstructorDecl *
1159+
NominalTypeDecl::getDistributedRemoteCallArgumentInitFunction() const {
1160+
auto mutableThis = const_cast<NominalTypeDecl *>(this);
1161+
return evaluateOrDefault(
1162+
getASTContext().evaluator,
1163+
GetDistributedRemoteCallArgumentInitFunctionRequest(mutableThis), nullptr);
1164+
}
1165+
11431166
AbstractFunctionDecl *ASTContext::getRemoteCallOnDistributedActorSystem(
11441167
NominalTypeDecl *actorOrSystem, bool isVoidReturn) const {
11451168
assert(actorOrSystem && "distributed actor (or system) decl must be provided");

lib/Sema/CodeSynthesisDistributedActor.cpp

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,13 +269,67 @@ deriveBodyDistributed_thunk(AbstractFunctionDecl *thunk, void *context) {
269269
auto recordArgumentDeclRef = UnresolvedDeclRefExpr::createImplicit(
270270
C, recordArgumentDecl->getName());
271271

272-
auto recordArgArgsList = ArgumentList::forImplicitCallTo(
273-
recordArgumentDeclRef->getName(),
272+
auto argumentName = param->getArgumentName().str();
273+
LiteralExpr *argumentLabelArg;
274+
if (argumentName.empty()) {
275+
argumentLabelArg = new (C) NilLiteralExpr(sloc, implicit);
276+
} else {
277+
argumentLabelArg =
278+
new (C) StringLiteralExpr(argumentName, SourceRange(), implicit);
279+
}
280+
auto parameterName = param->getParameterName().str();
281+
282+
283+
// --- Prepare the RemoteCallArgument<Value> for the argument
284+
auto argumentVarName = C.getIdentifier("_" + parameterName.str());
285+
StructDecl *RCA = C.getRemoteCallArgumentDecl();
286+
VarDecl *callArgVar =
287+
new (C) VarDecl(/*isStatic=*/false, VarDecl::Introducer::Let, sloc,
288+
argumentVarName, thunk);
289+
callArgVar->setImplicit();
290+
callArgVar->setSynthesized();
291+
292+
Pattern *callArgPattern = NamedPattern::createImplicit(C, callArgVar);
293+
294+
auto remoteCallArgumentInitDecl =
295+
RCA->getDistributedRemoteCallArgumentInitFunction();
296+
auto boundRCAType = BoundGenericType::get(
297+
RCA, Type(), {thunk->mapTypeIntoContext(param->getInterfaceType())});
298+
auto remoteCallArgumentInitDeclRef =
299+
TypeExpr::createImplicit(boundRCAType, C);
300+
301+
auto initCallArgArgs = ArgumentList::forImplicitCallTo(
302+
DeclNameRef(remoteCallArgumentInitDecl->getEffectiveFullName()),
274303
{
275-
new (C) DeclRefExpr(
304+
// label:
305+
argumentLabelArg,
306+
// name:
307+
new (C) StringLiteralExpr(parameterName, SourceRange(), implicit),
308+
// _ argument:
309+
new (C) DeclRefExpr(
276310
ConcreteDeclRef(param), dloc, implicit,
277311
AccessSemantics::Ordinary,
278312
thunk->mapTypeIntoContext(param->getInterfaceType()))
313+
},
314+
C);
315+
316+
auto initCallArgCallExpr =
317+
CallExpr::createImplicit(C, remoteCallArgumentInitDeclRef, initCallArgArgs);
318+
initCallArgCallExpr->setImplicit();
319+
320+
auto callArgPB = PatternBindingDecl::createImplicit(
321+
C, StaticSpellingKind::None, callArgPattern, initCallArgCallExpr, thunk);
322+
323+
remoteBranchStmts.push_back(callArgPB);
324+
remoteBranchStmts.push_back(callArgVar);
325+
326+
/// --- Pass the argumentRepr to the recordArgument function
327+
auto recordArgArgsList = ArgumentList::forImplicitCallTo(
328+
recordArgumentDeclRef->getName(),
329+
{
330+
new (C) DeclRefExpr(
331+
ConcreteDeclRef(callArgVar), dloc, implicit,
332+
AccessSemantics::Ordinary)
279333
}, C);
280334

281335
auto tryRecordArgExpr = TryExpr::createImplicit(C, sloc,

lib/Sema/TypeCheckDistributed.cpp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
305305
decl->getDescriptiveKind(), decl->getName(), identifier);
306306
decl->diagnose(diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
307307
decl->getName(), identifier,
308-
"mutating func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws\n");
308+
"mutating func recordArgument<Value: SerializationRequirement>(_ argument: RemoteCallArgument<Value>) throws\n");
309309
anyMissingAdHocRequirements = true;
310310
}
311311
if (checkAdHocRequirementAccessControl(decl, Proto, recordArgumentDecl))
@@ -731,6 +731,48 @@ GetDistributedRemoteCallTargetInitFunctionRequest::evaluate(
731731
return nullptr;
732732
}
733733

734+
ConstructorDecl*
735+
GetDistributedRemoteCallArgumentInitFunctionRequest::evaluate(
736+
Evaluator &evaluator,
737+
NominalTypeDecl *nominal) const {
738+
auto &C = nominal->getASTContext();
739+
740+
// not via `ensureDistributedModuleLoaded` to avoid generating a warning,
741+
// we won't be emitting the offending decl after all.
742+
if (!C.getLoadedModule(C.Id_Distributed))
743+
return nullptr;
744+
745+
if (!nominal->getDeclaredInterfaceType()->isEqual(
746+
C.getRemoteCallArgumentType()))
747+
return nullptr;
748+
749+
for (auto value : nominal->getMembers()) {
750+
auto ctor = dyn_cast<ConstructorDecl>(value);
751+
if (!ctor)
752+
continue;
753+
754+
auto params = ctor->getParameters();
755+
if (params->size() != 3)
756+
return nullptr;
757+
758+
// --- param: label
759+
if (!params->get(0)->getArgumentName().is("label"))
760+
return nullptr;
761+
762+
// --- param: name
763+
if (!params->get(1)->getArgumentName().is("name"))
764+
return nullptr;
765+
766+
// --- param: value
767+
if (params->get(2)->getArgumentName() != C.Id_value)
768+
return nullptr;
769+
770+
return ctor;
771+
}
772+
773+
return nullptr;
774+
}
775+
734776
NominalTypeDecl *
735777
GetDistributedActorInvocationDecoderRequest::evaluate(Evaluator &evaluator,
736778
NominalTypeDecl *actor) const {

0 commit comments

Comments
 (0)