Skip to content

[Distributed] Add name parameter to recordArgument for better interop #41799

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3493,9 +3493,12 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext {
/// Find, or potentially synthesize, the implicit 'id' property of this actor.
VarDecl *getDistributedActorIDProperty() const;

/// Find the 'RemoteCallTarget.init(_mangledName:)' initializer function
/// Find the 'RemoteCallTarget.init(_:)' initializer function
ConstructorDecl* getDistributedRemoteCallTargetInitFunction() const;

/// Find the 'RemoteCallArgument(label:name:value:)' initializer function
ConstructorDecl* getDistributedRemoteCallArgumentInitFunction() const;

/// Collect the set of protocols to which this type should implicitly
/// conform, such as AnyObject (for classes).
void getImplicitProtocols(SmallVectorImpl<ProtocolDecl *> &protocols);
Expand Down
1 change: 1 addition & 0 deletions include/swift/AST/KnownSDKTypes.def
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ KNOWN_SDK_TYPE_DECL(Distributed, DistributedActorSystem, ProtocolDecl, 0)
KNOWN_SDK_TYPE_DECL(Distributed, DistributedTargetInvocationEncoder, ProtocolDecl, 0)
KNOWN_SDK_TYPE_DECL(Distributed, DistributedTargetInvocationDecoder, ProtocolDecl, 0)
KNOWN_SDK_TYPE_DECL(Distributed, RemoteCallTarget, StructDecl, 0)
KNOWN_SDK_TYPE_DECL(Distributed, RemoteCallArgument, StructDecl, 1)

// String processing
KNOWN_SDK_TYPE_DECL(StringProcessing, Regex, StructDecl, 1)
Expand Down
21 changes: 20 additions & 1 deletion include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -1210,7 +1210,7 @@ class GetDistributedActorSystemPropertyRequest :
bool isCached() const { return true; }
};

/// Obtain the constructor of the RemoteCallTarget type.
/// Obtain the constructor of the 'RemoteCallTarget' type.
class GetDistributedRemoteCallTargetInitFunctionRequest :
public SimpleRequest<GetDistributedRemoteCallTargetInitFunctionRequest,
ConstructorDecl *(NominalTypeDecl *),
Expand All @@ -1229,6 +1229,25 @@ class GetDistributedRemoteCallTargetInitFunctionRequest :
bool isCached() const { return true; }
};

/// Obtain the constructor of the 'RemoteCallArgument' type.
class GetDistributedRemoteCallArgumentInitFunctionRequest :
public SimpleRequest<GetDistributedRemoteCallArgumentInitFunctionRequest,
ConstructorDecl *(NominalTypeDecl *),
RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;

private:
friend SimpleRequest;

ConstructorDecl *evaluate(Evaluator &evaluator,
NominalTypeDecl *nominal) const;

public:
// Caching
bool isCached() const { return true; }
};

/// Obtain the 'distributed thunk' for the passed-in function.
///
/// The thunk is responsible for invoking 'remoteCall' when invoked on a remote
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ SWIFT_REQUEST(TypeChecker, GetDistributedActorSystemPropertyRequest,
SWIFT_REQUEST(TypeChecker, GetDistributedRemoteCallTargetInitFunctionRequest,
ConstructorDecl *(NominalTypeDecl *),
Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, GetDistributedRemoteCallArgumentInitFunctionRequest,
ConstructorDecl *(NominalTypeDecl *),
Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, GetDistributedActorInvocationDecoderRequest,
NominalTypeDecl *(NominalTypeDecl *),
Cached, NoLocationInfo)
Expand Down
105 changes: 64 additions & 41 deletions lib/AST/DistributedDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,11 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
auto &C = getASTContext();
auto module = getParentModule();

auto func = dyn_cast<FuncDecl>(this);
if (!func) {
return false;
}

// === Check base name
if (getBaseIdentifier() != C.Id_recordArgument) {
return false;
Expand Down Expand Up @@ -614,6 +619,12 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
return false;
}

// --- must be mutating, if it is defined in a struct
if (isa<StructDecl>(getDeclContext()) &&
!func->isMutating()) {
return false;
}

// --- Check number of generic parameters
auto genericParams = getGenericParams();
unsigned int expectedGenericParamNum = 1;
Expand All @@ -639,56 +650,60 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
return false;
}

// --- Check parameter: _ argument
auto argumentParam = params->get(0);
if (!argumentParam->getArgumentName().is("")) {
return false;
}

// === Check generic parameters in detail
// --- Check: Argument: SerializationRequirement
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];

auto sig = getGenericSignature();
auto requirements = sig.getRequirements();
// --- Check parameter: _ argument
auto argumentParam = params->get(0);
if (!argumentParam->getArgumentName().empty()) {
return false;
}

if (requirements.size() != expectedRequirementsNum) {
return false;
}
auto argumentTy = argumentParam->getInterfaceType();
auto argumentInContextTy = mapTypeIntoContext(argumentTy);
if (argumentInContextTy->getAnyNominal() == C.getRemoteCallArgumentDecl()) {
auto argGenericParams = argumentInContextTy->getStructOrBoundGenericStruct()
->getGenericParams()->getParams();
if (argGenericParams.size() != 1) {
return false;
}

// --- Check the expected requirements
// --- all the Argument requirements ---
// conforms_to: Argument Decodable
// conforms_to: Argument Encodable
// ...
// the <Value> of the RemoteCallArgument<Value>
auto remoteCallArgValueGenericTy =
mapTypeIntoContext(argGenericParams[0]->getInterfaceType())
->getDesugaredType()
->getMetatypeInstanceType();
// expected (the <Value> from the recordArgument<Value>)
auto expectedGenericParamTy = mapTypeIntoContext(
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());

if (!remoteCallArgValueGenericTy->isEqual(expectedGenericParamTy)) {
return false;
}
} else {
return false;
}

auto func = dyn_cast<FuncDecl>(this);
if (!func) {
return false;
}

auto resultType = func->mapTypeIntoContext(argumentParam->getInterfaceType())
->getDesugaredType();
auto resultParamType = func->mapTypeIntoContext(
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
// The result of the function must be the `Res` generic argument.
if (!resultType->isEqual(resultParamType)) {
return false;
}
auto sig = getGenericSignature();
auto requirements = sig.getRequirements();

for (auto requirementProto : requirementProtos) {
auto conformance = module->lookupConformance(resultType, requirementProto);
if (conformance.isInvalid()) {
if (requirements.size() != expectedRequirementsNum) {
return false;
}
}

// === Check result type: Void
if (!func->getResultInterfaceType()->isVoid()) {
return false;
}
// --- Check the expected requirements
// --- all the Argument requirements ---
// e.g.
// conforms_to: Argument Decodable
// conforms_to: Argument Encodable
// ...

return true;
// === Check result type: Void
if (!func->getResultInterfaceType()->isVoid()) {
return false;
}

return true;
}

bool
Expand Down Expand Up @@ -879,8 +894,8 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordErrorType() cons
}

// --- Check parameter: _ errorType
auto argumentParam = params->get(0);
if (!argumentParam->getArgumentName().is("")) {
auto errorTypeParam = params->get(0);
if (!errorTypeParam->getArgumentName().is("")) {
return false;
}

Expand Down Expand Up @@ -1140,6 +1155,14 @@ NominalTypeDecl::getDistributedRemoteCallTargetInitFunction() const {
GetDistributedRemoteCallTargetInitFunctionRequest(mutableThis), nullptr);
}

ConstructorDecl *
NominalTypeDecl::getDistributedRemoteCallArgumentInitFunction() const {
auto mutableThis = const_cast<NominalTypeDecl *>(this);
return evaluateOrDefault(
getASTContext().evaluator,
GetDistributedRemoteCallArgumentInitFunctionRequest(mutableThis), nullptr);
}

AbstractFunctionDecl *ASTContext::getRemoteCallOnDistributedActorSystem(
NominalTypeDecl *actorOrSystem, bool isVoidReturn) const {
assert(actorOrSystem && "distributed actor (or system) decl must be provided");
Expand Down
60 changes: 57 additions & 3 deletions lib/Sema/CodeSynthesisDistributedActor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,67 @@ deriveBodyDistributed_thunk(AbstractFunctionDecl *thunk, void *context) {
auto recordArgumentDeclRef = UnresolvedDeclRefExpr::createImplicit(
C, recordArgumentDecl->getName());

auto recordArgArgsList = ArgumentList::forImplicitCallTo(
recordArgumentDeclRef->getName(),
auto argumentName = param->getArgumentName().str();
LiteralExpr *argumentLabelArg;
if (argumentName.empty()) {
argumentLabelArg = new (C) NilLiteralExpr(sloc, implicit);
} else {
argumentLabelArg =
new (C) StringLiteralExpr(argumentName, SourceRange(), implicit);
}
auto parameterName = param->getParameterName().str();


// --- Prepare the RemoteCallArgument<Value> for the argument
auto argumentVarName = C.getIdentifier("_" + parameterName.str());
StructDecl *RCA = C.getRemoteCallArgumentDecl();
VarDecl *callArgVar =
new (C) VarDecl(/*isStatic=*/false, VarDecl::Introducer::Let, sloc,
argumentVarName, thunk);
callArgVar->setImplicit();
callArgVar->setSynthesized();

Pattern *callArgPattern = NamedPattern::createImplicit(C, callArgVar);

auto remoteCallArgumentInitDecl =
RCA->getDistributedRemoteCallArgumentInitFunction();
auto boundRCAType = BoundGenericType::get(
RCA, Type(), {thunk->mapTypeIntoContext(param->getInterfaceType())});
auto remoteCallArgumentInitDeclRef =
TypeExpr::createImplicit(boundRCAType, C);

auto initCallArgArgs = ArgumentList::forImplicitCallTo(
DeclNameRef(remoteCallArgumentInitDecl->getEffectiveFullName()),
{
new (C) DeclRefExpr(
// label:
argumentLabelArg,
// name:
new (C) StringLiteralExpr(parameterName, SourceRange(), implicit),
// _ argument:
new (C) DeclRefExpr(
ConcreteDeclRef(param), dloc, implicit,
AccessSemantics::Ordinary,
thunk->mapTypeIntoContext(param->getInterfaceType()))
},
C);

auto initCallArgCallExpr =
CallExpr::createImplicit(C, remoteCallArgumentInitDeclRef, initCallArgArgs);
initCallArgCallExpr->setImplicit();

auto callArgPB = PatternBindingDecl::createImplicit(
C, StaticSpellingKind::None, callArgPattern, initCallArgCallExpr, thunk);

remoteBranchStmts.push_back(callArgPB);
remoteBranchStmts.push_back(callArgVar);

/// --- Pass the argumentRepr to the recordArgument function
auto recordArgArgsList = ArgumentList::forImplicitCallTo(
recordArgumentDeclRef->getName(),
{
new (C) DeclRefExpr(
ConcreteDeclRef(callArgVar), dloc, implicit,
AccessSemantics::Ordinary)
}, C);

auto tryRecordArgExpr = TryExpr::createImplicit(C, sloc,
Expand Down
44 changes: 43 additions & 1 deletion lib/Sema/TypeCheckDistributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
decl->getDescriptiveKind(), decl->getName(), identifier);
decl->diagnose(diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
decl->getName(), identifier,
"mutating func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws\n");
"mutating func recordArgument<Value: SerializationRequirement>(_ argument: RemoteCallArgument<Value>) throws\n");
anyMissingAdHocRequirements = true;
}
if (checkAdHocRequirementAccessControl(decl, Proto, recordArgumentDecl))
Expand Down Expand Up @@ -731,6 +731,48 @@ GetDistributedRemoteCallTargetInitFunctionRequest::evaluate(
return nullptr;
}

ConstructorDecl*
GetDistributedRemoteCallArgumentInitFunctionRequest::evaluate(
Evaluator &evaluator,
NominalTypeDecl *nominal) const {
auto &C = nominal->getASTContext();

// not via `ensureDistributedModuleLoaded` to avoid generating a warning,
// we won't be emitting the offending decl after all.
if (!C.getLoadedModule(C.Id_Distributed))
return nullptr;

if (!nominal->getDeclaredInterfaceType()->isEqual(
C.getRemoteCallArgumentType()))
return nullptr;

for (auto value : nominal->getMembers()) {
auto ctor = dyn_cast<ConstructorDecl>(value);
if (!ctor)
continue;

auto params = ctor->getParameters();
if (params->size() != 3)
return nullptr;

// --- param: label
if (!params->get(0)->getArgumentName().is("label"))
return nullptr;

// --- param: name
if (!params->get(1)->getArgumentName().is("name"))
return nullptr;

// --- param: value
if (params->get(2)->getArgumentName() != C.Id_value)
return nullptr;

return ctor;
}

return nullptr;
}

NominalTypeDecl *
GetDistributedActorInvocationDecoderRequest::evaluate(Evaluator &evaluator,
NominalTypeDecl *actor) const {
Expand Down
Loading