Skip to content

Commit 5a5d1ba

Browse files
committed
[Distributed] Implement RemoteCallArgument
1 parent 20f28f8 commit 5a5d1ba

26 files changed

+282
-102
lines changed

include/swift/AST/Decl.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3493,9 +3493,12 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext {
34933493
/// Find, or potentially synthesize, the implicit 'id' property of this actor.
34943494
VarDecl *getDistributedActorIDProperty() const;
34953495

3496-
/// Find the 'RemoteCallTarget.init(_mangledName:)' initializer function
3496+
/// Find the 'RemoteCallTarget.init(_:)' initializer function
34973497
ConstructorDecl* getDistributedRemoteCallTargetInitFunction() const;
34983498

3499+
/// Find the 'RemoteCallArgument(label:name:value:)' initializer function
3500+
ConstructorDecl* getDistributedRemoteCallArgumentInitFunction() const;
3501+
34993502
/// Collect the set of protocols to which this type should implicitly
35003503
/// conform, such as AnyObject (for classes).
35013504
void getImplicitProtocols(SmallVectorImpl<ProtocolDecl *> &protocols);

include/swift/AST/KnownSDKTypes.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ KNOWN_SDK_TYPE_DECL(Distributed, DistributedActorSystem, ProtocolDecl, 0)
4949
KNOWN_SDK_TYPE_DECL(Distributed, DistributedTargetInvocationEncoder, ProtocolDecl, 0)
5050
KNOWN_SDK_TYPE_DECL(Distributed, DistributedTargetInvocationDecoder, ProtocolDecl, 0)
5151
KNOWN_SDK_TYPE_DECL(Distributed, RemoteCallTarget, StructDecl, 0)
52+
KNOWN_SDK_TYPE_DECL(Distributed, RemoteCallArgument, StructDecl, 1)
5253

5354
// String processing
5455
KNOWN_SDK_TYPE_DECL(StringProcessing, Regex, StructDecl, 1)

include/swift/AST/TypeCheckRequests.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1210,7 +1210,7 @@ class GetDistributedActorSystemPropertyRequest :
12101210
bool isCached() const { return true; }
12111211
};
12121212

1213-
/// Obtain the constructor of the RemoteCallTarget type.
1213+
/// Obtain the constructor of the 'RemoteCallTarget' type.
12141214
class GetDistributedRemoteCallTargetInitFunctionRequest :
12151215
public SimpleRequest<GetDistributedRemoteCallTargetInitFunctionRequest,
12161216
ConstructorDecl *(NominalTypeDecl *),
@@ -1229,6 +1229,25 @@ class GetDistributedRemoteCallTargetInitFunctionRequest :
12291229
bool isCached() const { return true; }
12301230
};
12311231

1232+
/// Obtain the constructor of the 'RemoteCallArgument' type.
1233+
class GetDistributedRemoteCallArgumentInitFunctionRequest :
1234+
public SimpleRequest<GetDistributedRemoteCallArgumentInitFunctionRequest,
1235+
ConstructorDecl *(NominalTypeDecl *),
1236+
RequestFlags::Cached> {
1237+
public:
1238+
using SimpleRequest::SimpleRequest;
1239+
1240+
private:
1241+
friend SimpleRequest;
1242+
1243+
ConstructorDecl *evaluate(Evaluator &evaluator,
1244+
NominalTypeDecl *nominal) const;
1245+
1246+
public:
1247+
// Caching
1248+
bool isCached() const { return true; }
1249+
};
1250+
12321251
/// Obtain the 'distributed thunk' for the passed-in function.
12331252
///
12341253
/// The thunk is responsible for invoking 'remoteCall' when invoked on a remote

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ SWIFT_REQUEST(TypeChecker, GetDistributedActorSystemPropertyRequest,
145145
SWIFT_REQUEST(TypeChecker, GetDistributedRemoteCallTargetInitFunctionRequest,
146146
ConstructorDecl *(NominalTypeDecl *),
147147
Cached, NoLocationInfo)
148+
SWIFT_REQUEST(TypeChecker, GetDistributedRemoteCallArgumentInitFunctionRequest,
149+
ConstructorDecl *(NominalTypeDecl *),
150+
Cached, NoLocationInfo)
148151
SWIFT_REQUEST(TypeChecker, GetDistributedActorInvocationDecoderRequest,
149152
NominalTypeDecl *(NominalTypeDecl *),
150153
Cached, NoLocationInfo)

lib/AST/DistributedDecl.cpp

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,11 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
581581
auto &C = getASTContext();
582582
auto module = getParentModule();
583583

584+
auto func = dyn_cast<FuncDecl>(this);
585+
if (!func) {
586+
return false;
587+
}
588+
584589
// === Check base name
585590
if (getBaseIdentifier() != C.Id_recordArgument) {
586591
return false;
@@ -614,6 +619,12 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
614619
return false;
615620
}
616621

622+
// --- must be mutating, if it is defined in a struct
623+
if (isa<StructDecl>(getDeclContext()) &&
624+
!func->isMutating()) {
625+
return false;
626+
}
627+
617628
// --- Check number of generic parameters
618629
auto genericParams = getGenericParams();
619630
unsigned int expectedGenericParamNum = 1;
@@ -635,28 +646,43 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
635646

636647
// === Check all parameters
637648
auto params = getParameters();
638-
if (params->size() != 2) {
649+
if (params->size() != 1) {
639650
return false;
640651
}
641652

642-
// --- Check parameter: label
643-
auto labelParam = params->get(0);
644-
if (!labelParam->getArgumentName().is("name")) {
645-
return false;
646-
}
647-
if (!labelParam->getInterfaceType()->isEqual(C.getStringType())) {
653+
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
654+
655+
// --- Check parameter: _ argument
656+
auto argumentParam = params->get(0);
657+
if (!argumentParam->getArgumentName().empty()) {
648658
return false;
649659
}
650660

651-
// --- Check parameter: _ argument
652-
auto argumentParam = params->get(1);
653-
if (!argumentParam->getArgumentName().is("")) {
661+
auto argumentTy = argumentParam->getInterfaceType();
662+
auto argumentInContextTy = mapTypeIntoContext(argumentTy);
663+
if (argumentInContextTy->getAnyNominal() == C.getRemoteCallArgumentDecl()) {
664+
auto argGenericParams = argumentInContextTy->getStructOrBoundGenericStruct()
665+
->getGenericParams()->getParams();
666+
if (argGenericParams.size() != 1) {
667+
return false;
668+
}
669+
670+
// the <Value> of the RemoteCallArgument<Value>
671+
auto remoteCallArgValueGenericTy =
672+
mapTypeIntoContext(argGenericParams[0]->getInterfaceType())
673+
->getDesugaredType()
674+
->getMetatypeInstanceType();
675+
// expected (the <Value> from the recordArgument<Value>)
676+
auto expectedGenericParamTy = mapTypeIntoContext(
677+
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
678+
679+
if (!remoteCallArgValueGenericTy->isEqual(expectedGenericParamTy)) {
680+
return false;
681+
}
682+
} else {
654683
return false;
655684
}
656685

657-
// === Check generic parameters in detail
658-
// --- Check: Argument: SerializationRequirement
659-
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
660686

661687
auto sig = getGenericSignature();
662688
auto requirements = sig.getRequirements();
@@ -672,29 +698,6 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
672698
// conforms_to: Argument Encodable
673699
// ...
674700

675-
auto func = dyn_cast<FuncDecl>(this);
676-
if (!func) {
677-
return false;
678-
}
679-
680-
auto resultType =
681-
func->mapTypeIntoContext(argumentParam->getInterfaceType())
682-
->getDesugaredType();
683-
auto resultParamType = func->mapTypeIntoContext(
684-
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
685-
// The result of the function must be the `Res` generic argument.
686-
if (!resultType->isEqual(resultParamType)) {
687-
return false;
688-
}
689-
690-
for (auto requirementProto : requirementProtos) {
691-
auto conformance =
692-
module->lookupConformance(resultType, requirementProto);
693-
if (conformance.isInvalid()) {
694-
return false;
695-
}
696-
}
697-
698701
// === Check result type: Void
699702
if (!func->getResultInterfaceType()->isVoid()) {
700703
return false;
@@ -1152,6 +1155,14 @@ NominalTypeDecl::getDistributedRemoteCallTargetInitFunction() const {
11521155
GetDistributedRemoteCallTargetInitFunctionRequest(mutableThis), nullptr);
11531156
}
11541157

1158+
ConstructorDecl *
1159+
NominalTypeDecl::getDistributedRemoteCallArgumentInitFunction() const {
1160+
auto mutableThis = const_cast<NominalTypeDecl *>(this);
1161+
return evaluateOrDefault(
1162+
getASTContext().evaluator,
1163+
GetDistributedRemoteCallArgumentInitFunctionRequest(mutableThis), nullptr);
1164+
}
1165+
11551166
AbstractFunctionDecl *ASTContext::getRemoteCallOnDistributedActorSystem(
11561167
NominalTypeDecl *actorOrSystem, bool isVoidReturn) const {
11571168
assert(actorOrSystem && "distributed actor (or system) decl must be provided");

lib/Sema/CodeSynthesisDistributedActor.cpp

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -270,17 +270,66 @@ deriveBodyDistributed_thunk(AbstractFunctionDecl *thunk, void *context) {
270270
C, recordArgumentDecl->getName());
271271

272272
auto argumentName = param->getArgumentName().str();
273-
auto recordArgArgsList = ArgumentList::forImplicitCallTo(
274-
recordArgumentDeclRef->getName(),
273+
LiteralExpr *argumentLabelArg;
274+
if (argumentName.empty()) {
275+
argumentLabelArg = new (C) NilLiteralExpr(sloc, implicit);
276+
} else {
277+
argumentLabelArg =
278+
new (C) StringLiteralExpr(argumentName, SourceRange(), implicit);
279+
}
280+
auto parameterName = param->getParameterName().str();
281+
282+
283+
// --- Prepare the RemoteCallArgument<Value> for the argument
284+
auto argumentVarName = C.getIdentifier("_" + parameterName.str());
285+
StructDecl *RCA = C.getRemoteCallArgumentDecl();
286+
VarDecl *callArgVar =
287+
new (C) VarDecl(/*isStatic=*/false, VarDecl::Introducer::Let, sloc,
288+
argumentVarName, thunk);
289+
callArgVar->setImplicit();
290+
callArgVar->setSynthesized();
291+
292+
Pattern *callArgPattern = NamedPattern::createImplicit(C, callArgVar);
293+
294+
auto remoteCallArgumentInitDecl =
295+
RCA->getDistributedRemoteCallArgumentInitFunction();
296+
auto boundRCAType = BoundGenericType::get(
297+
RCA, Type(), {thunk->mapTypeIntoContext(param->getInterfaceType())});
298+
auto remoteCallArgumentInitDeclRef =
299+
TypeExpr::createImplicit(boundRCAType, C);
300+
301+
auto initCallArgArgs = ArgumentList::forImplicitCallTo(
302+
DeclNameRef(remoteCallArgumentInitDecl->getEffectiveFullName()),
275303
{
276-
// name:
277-
new (C) StringLiteralExpr(argumentName, SourceRange(),
278-
/*implicit=*/true),
279-
// _ argument:
280-
new (C) DeclRefExpr(
304+
// label:
305+
argumentLabelArg,
306+
// name:
307+
new (C) StringLiteralExpr(parameterName, SourceRange(), implicit),
308+
// _ argument:
309+
new (C) DeclRefExpr(
281310
ConcreteDeclRef(param), dloc, implicit,
282311
AccessSemantics::Ordinary,
283312
thunk->mapTypeIntoContext(param->getInterfaceType()))
313+
},
314+
C);
315+
316+
auto initCallArgCallExpr =
317+
CallExpr::createImplicit(C, remoteCallArgumentInitDeclRef, initCallArgArgs);
318+
initCallArgCallExpr->setImplicit();
319+
320+
auto callArgPB = PatternBindingDecl::createImplicit(
321+
C, StaticSpellingKind::None, callArgPattern, initCallArgCallExpr, thunk);
322+
323+
remoteBranchStmts.push_back(callArgPB);
324+
remoteBranchStmts.push_back(callArgVar);
325+
326+
/// --- Pass the argumentRepr to the recordArgument function
327+
auto recordArgArgsList = ArgumentList::forImplicitCallTo(
328+
recordArgumentDeclRef->getName(),
329+
{
330+
new (C) DeclRefExpr(
331+
ConcreteDeclRef(callArgVar), dloc, implicit,
332+
AccessSemantics::Ordinary)
284333
}, C);
285334

286335
auto tryRecordArgExpr = TryExpr::createImplicit(C, sloc,

lib/Sema/TypeCheckDistributed.cpp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
305305
decl->getDescriptiveKind(), decl->getName(), identifier);
306306
decl->diagnose(diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
307307
decl->getName(), identifier,
308-
"mutating func recordArgument<Argument: SerializationRequirement>(name: String, _ argument: Argument) throws\n");
308+
"mutating func recordArgument<Value: SerializationRequirement>(_ argument: RemoteCallArgument<Value>) throws\n");
309309
anyMissingAdHocRequirements = true;
310310
}
311311
if (checkAdHocRequirementAccessControl(decl, Proto, recordArgumentDecl))
@@ -731,6 +731,48 @@ GetDistributedRemoteCallTargetInitFunctionRequest::evaluate(
731731
return nullptr;
732732
}
733733

734+
ConstructorDecl*
735+
GetDistributedRemoteCallArgumentInitFunctionRequest::evaluate(
736+
Evaluator &evaluator,
737+
NominalTypeDecl *nominal) const {
738+
auto &C = nominal->getASTContext();
739+
740+
// not via `ensureDistributedModuleLoaded` to avoid generating a warning,
741+
// we won't be emitting the offending decl after all.
742+
if (!C.getLoadedModule(C.Id_Distributed))
743+
return nullptr;
744+
745+
if (!nominal->getDeclaredInterfaceType()->isEqual(
746+
C.getRemoteCallArgumentType()))
747+
return nullptr;
748+
749+
for (auto value : nominal->getMembers()) {
750+
auto ctor = dyn_cast<ConstructorDecl>(value);
751+
if (!ctor)
752+
continue;
753+
754+
auto params = ctor->getParameters();
755+
if (params->size() != 3)
756+
return nullptr;
757+
758+
// --- param: label
759+
if (!params->get(0)->getArgumentName().is("label"))
760+
return nullptr;
761+
762+
// --- param: name
763+
if (!params->get(1)->getArgumentName().is("name"))
764+
return nullptr;
765+
766+
// --- param: value
767+
if (params->get(2)->getArgumentName() != C.Id_value)
768+
return nullptr;
769+
770+
return ctor;
771+
}
772+
773+
return nullptr;
774+
}
775+
734776
NominalTypeDecl *
735777
GetDistributedActorInvocationDecoderRequest::evaluate(Evaluator &evaluator,
736778
NominalTypeDecl *actor) const {

stdlib/public/Distributed/DistributedActorSystem.swift

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,9 @@ public protocol DistributedTargetInvocationEncoder {
420420
// ///
421421
// /// Record an argument of `Argument` type.
422422
// /// This will be invoked for every argument of the target, in declaration order.
423-
// mutating func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws
423+
// mutating func recordArgument<Value: SerializationRequirement>(
424+
// _ argument: DistributedTargetArgument<Value>
425+
// ) throws
424426

425427
/// Record the error type of the distributed method.
426428
/// This method will not be invoked if the target is not throwing.
@@ -435,8 +437,47 @@ public protocol DistributedTargetInvocationEncoder {
435437
mutating func doneRecording() throws
436438
}
437439

440+
/// Represents an argument passed to a distributed call target.
438441
@available(SwiftStdlib 5.7, *)
439-
public
442+
public struct RemoteCallArgument<Value> {
443+
/// The "argument label" of the argument.
444+
/// The label is the name visible name used in external calls made to this
445+
/// target, e.g. for `func hello(label name: String)` it is `label`.
446+
///
447+
/// If no label is specified (i.e. `func hi(name: String)`), the `label`,
448+
/// value is empty, however `effectiveLabel` is equal to the `name`.
449+
///
450+
/// In most situations, using `effectiveLabel` is more useful to identify
451+
/// the user-visible name of this argument.
452+
public let label: String?
453+
454+
/// The effective label of this argument, i.e. if no explicit `label` was set
455+
/// this defaults to the `name`. This reflects the semantics of call sites of
456+
/// function declarations without explicit label definitions in Swift.
457+
public var effectiveLabel: String {
458+
return label ?? name
459+
}
460+
461+
/// The internal name of parameter this argument is accessible as in the
462+
/// function body. It is not part of the functions API and may change without
463+
/// breaking the target identifier.
464+
///
465+
/// If the method did not declare an explicit `label`, it is used as the
466+
/// `effectiveLabel`.
467+
public let name: String
468+
469+
/// The value of the argument being passed to the call.
470+
/// As `RemoteCallArgument` is always used in conjunction with
471+
/// `recordArgument` and populated by the compiler, this Value will generally
472+
/// conform to a distributed actor system's `SerializationRequirement`.
473+
public let value: Value
474+
475+
public init(label: String?, name: String, value: Value) {
476+
self.label = label
477+
self.name = name
478+
self.value = value
479+
}
480+
}
440481

441482
/// Decoder that must be provided to `executeDistributedTarget` and is used
442483
/// by the Swift runtime to decode arguments of the invocation.

stdlib/public/Distributed/LocalTestingDistributedActorSystem.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ public struct LocalTestingInvocationEncoder: DistributedTargetInvocationEncoder
151151
fatalError("Attempted to call encoder method in a local-only actor system")
152152
}
153153

154-
public mutating func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws {
154+
public mutating func recordArgument<Value: SerializationRequirement>(_ argument: RemoteCallArgument<Value>) throws {
155155
fatalError("Attempted to call encoder method in a local-only actor system")
156156
}
157157

0 commit comments

Comments
 (0)