Skip to content

Commit e2c61d3

Browse files
committed
[Distributed] Move dist funcdecl getters to ASTContext
1 parent 6da455e commit e2c61d3

File tree

5 files changed

+94
-135
lines changed

5 files changed

+94
-135
lines changed

include/swift/AST/ASTContext.h

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -670,38 +670,38 @@ class ASTContext final {
670670
NominalTypeDecl *actorOrSystem,
671671
bool isVoidReturn) const;
672672

673+
/// Retrieve the declaration of DistributedActorSystem.make().
674+
///
675+
/// \param actorOrSystem distributed actor or actor system to get the
676+
/// remoteCall function for. Since the method we're looking for is an ad-hoc
677+
/// requirement, a specific type MUST be passed here as it is not possible
678+
/// to obtain the decl from just the `DistributedActorSystem` protocol type.
679+
FuncDecl *getMakeInvocationEncoderOnDistributedActorSystem(
680+
NominalTypeDecl *actorOrSystem) const;
681+
673682
// Retrieve the declaration of DistributedInvocationEncoder.recordArgument(_:).
674683
//
675684
// \param nominal optionally provide a 'NominalTypeDecl' from which the
676685
// function decl shall be extracted. This is useful to avoid witness calls
677686
// through the protocol which is looked up when nominal is null.
678687
FuncDecl *getRecordArgumentOnDistributedInvocationEncoder(
679-
NominalTypeDecl *nominal = nullptr) const;
688+
NominalTypeDecl *nominal) const;
680689

681-
// Retrieve the declaration of DistributedInvocationEncoder.recordErrorType().
682-
//
683-
// \param nominal optionally provide a 'NominalTypeDecl' from which the
684-
// function decl shall be extracted. This is useful to avoid witness calls
685-
// through the protocol which is looked up when nominal is null.
690+
// Retrieve the declaration of DistributedInvocationEncoder.recordErrorType(_:).
686691
FuncDecl *getRecordErrorTypeOnDistributedInvocationEncoder(
687-
NominalTypeDecl *nominal = nullptr) const;
692+
NominalTypeDecl *nominal) const;
688693

689-
// Retrieve the declaration of DistributedInvocationEncoder.recordReturnType().
690-
//
691-
// \param nominal optionally provide a 'NominalTypeDecl' from which the
692-
// function decl shall be extracted. This is useful to avoid witness calls
693-
// through the protocol which is looked up when nominal is null.
694+
// Retrieve the declaration of DistributedInvocationEncoder.recordReturnType(_:).
694695
FuncDecl *getRecordReturnTypeOnDistributedInvocationEncoder(
695-
NominalTypeDecl *nominal = nullptr) const;
696+
NominalTypeDecl *nominal) const;
696697

697698
// Retrieve the declaration of DistributedInvocationEncoder.doneRecording().
698699
//
699700
// \param nominal optionally provide a 'NominalTypeDecl' from which the
700701
// function decl shall be extracted. This is useful to avoid witness calls
701702
// through the protocol which is looked up when nominal is null.
702703
FuncDecl *getDoneRecordingOnDistributedInvocationEncoder(
703-
NominalTypeDecl *nominal = nullptr) const;
704-
704+
NominalTypeDecl *nominal) const;
705705

706706
/// Look for the declaration with the given name within the
707707
/// passed in module.

include/swift/AST/Decl.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3370,9 +3370,6 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext {
33703370
/// Find, or potentially synthesize, the implicit 'id' property of this actor.
33713371
VarDecl *getDistributedActorIDProperty() const;
33723372

3373-
/// Find the 'makeInvocation' function.
3374-
AbstractFunctionDecl* getDistributedActorSystemMakeInvocationEncoderFunction() const;
3375-
33763373
/// Find the 'RemoteCallTarget.init(_mangledName:)' initializer function
33773374
ConstructorDecl* getDistributedRemoteCallTargetInitFunction() const;
33783375

lib/AST/ASTContext.cpp

Lines changed: 49 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -267,18 +267,6 @@ struct ASTContext::Implementation {
267267
/// -> Builtin.Int1
268268
FuncDecl *IsOSVersionAtLeastDecl = nullptr;
269269

270-
/// func recordArgument(_:) throws
271-
FuncDecl *RecordArgumentDistributedInvocationEncoderDecl = nullptr;
272-
273-
/// func recordErrorType(_:) throws
274-
FuncDecl *RecordErrorTypeDistributedInvocationEncoderDecl = nullptr;
275-
276-
/// func recordReturnType(_:) throws
277-
FuncDecl *RecordReturnTypeDistributedInvocationEncoderDecl = nullptr;
278-
279-
/// func doneRecording() throws
280-
FuncDecl *DoneRecordingDistributedInvocationEncoderDecl = nullptr;
281-
282270
/// The set of known protocols, lazily populated as needed.
283271
ProtocolDecl *KnownProtocols[NumKnownProtocols] = { };
284272

@@ -1295,7 +1283,7 @@ AbstractFunctionDecl *ASTContext::getRemoteCallOnDistributedActorSystem(
12951283
NominalTypeDecl *actorOrSystem, bool isVoidReturn) const {
12961284
assert(actorOrSystem && "distributed actor (or system) decl must be provided");
12971285
const NominalTypeDecl *system = actorOrSystem;
1298-
if (actorOrSystem && actorOrSystem->isDistributedActor()) {
1286+
if (actorOrSystem->isDistributedActor()) {
12991287
auto var = actorOrSystem->getDistributedActorSystemProperty();
13001288
system = var->getInterfaceType()->getAnyNominal();
13011289
}
@@ -1310,113 +1298,99 @@ AbstractFunctionDecl *ASTContext::getRemoteCallOnDistributedActorSystem(
13101298
nullptr);
13111299
}
13121300

1313-
FuncDecl *ASTContext::getRecordArgumentOnDistributedInvocationEncoder(
1314-
NominalTypeDecl *nominal) const {
1315-
if (getImpl().RecordArgumentDistributedInvocationEncoderDecl) {
1316-
return getImpl().RecordArgumentDistributedInvocationEncoderDecl;
1301+
FuncDecl *ASTContext::getMakeInvocationEncoderOnDistributedActorSystem(
1302+
NominalTypeDecl *actorOrSystem) const {
1303+
NominalTypeDecl *system = actorOrSystem;
1304+
assert(actorOrSystem && "distributed actor (or system) decl must be provided");
1305+
if (actorOrSystem->isDistributedActor()) {
1306+
auto var = actorOrSystem->getDistributedActorSystemProperty();
1307+
system = var->getInterfaceType()->getAnyNominal();
13171308
}
13181309

1319-
NominalTypeDecl *encoderProto = nominal ?
1320-
nominal :
1321-
getProtocol(KnownProtocolKind::DistributedTargetInvocationEncoder);
1322-
assert(encoderProto && "Missing DistributedTargetInvocationEncoder protocol");
1323-
for (auto result : encoderProto->lookupDirect(Id_recordArgument)) {
1310+
for (auto result : system->lookupDirect(Id_makeInvocationEncoder)) {
13241311
auto *fd = dyn_cast<FuncDecl>(result);
13251312
if (!fd)
13261313
continue;
1314+
if (fd->getParameters()->size() != 0)
1315+
continue;
1316+
if (fd->hasAsync())
1317+
continue;
1318+
if (fd->hasThrows())
1319+
continue;
1320+
// TODO(distributed): more checks, return type etc
13271321

1322+
return fd;
1323+
}
1324+
1325+
return nullptr;
1326+
}
1327+
1328+
FuncDecl *ASTContext::getRecordArgumentOnDistributedInvocationEncoder(
1329+
NominalTypeDecl *nominal) const {
1330+
for (auto result : nominal->lookupDirect(Id_recordArgument)) {
1331+
auto *fd = dyn_cast<FuncDecl>(result);
1332+
if (!fd)
1333+
continue;
13281334
if (fd->getParameters()->size() != 1)
13291335
continue;
1330-
1336+
if (fd->hasAsync())
1337+
continue;
1338+
if (!fd->hasThrows())
1339+
continue;
13311340
// TODO(distributed): more checks
13321341

1333-
if (fd->getResultInterfaceType()->isVoid() &&
1334-
fd->hasThrows() &&
1335-
!fd->hasAsync()) {
1336-
getImpl().RecordArgumentDistributedInvocationEncoderDecl = fd;
1342+
if (fd->getResultInterfaceType()->isVoid())
13371343
return fd;
1338-
}
13391344
}
13401345

13411346
return nullptr;
13421347
}
13431348

13441349
FuncDecl *ASTContext::getRecordErrorTypeOnDistributedInvocationEncoder(
13451350
NominalTypeDecl *nominal) const {
1346-
if (getImpl().RecordErrorTypeDistributedInvocationEncoderDecl) {
1347-
return getImpl().RecordErrorTypeDistributedInvocationEncoderDecl;
1348-
}
1349-
1350-
NominalTypeDecl *encoderProto =
1351-
nominal
1352-
? nominal
1353-
: getProtocol(KnownProtocolKind::DistributedTargetInvocationEncoder);
1354-
assert(encoderProto && "Missing DistributedTargetInvocationEncoder protocol");
1355-
for (auto result : encoderProto->lookupDirect(Id_recordErrorType)) {
1351+
for (auto result : nominal->lookupDirect(Id_recordErrorType)) {
13561352
auto *fd = dyn_cast<FuncDecl>(result);
13571353
if (!fd)
13581354
continue;
1359-
13601355
if (fd->getParameters()->size() != 1)
13611356
continue;
1357+
if (fd->hasAsync())
1358+
continue;
1359+
if (!fd->hasThrows())
1360+
continue;
1361+
// TODO(distributed): more checks
13621362

1363-
// TODO(distributed): more checks that the arg type matches (!!!)
1364-
1365-
if (fd->getResultInterfaceType()->isVoid() &&
1366-
fd->hasThrows() &&
1367-
!fd->hasAsync()) {
1368-
getImpl().RecordErrorTypeDistributedInvocationEncoderDecl = fd;
1363+
if (fd->getResultInterfaceType()->isVoid())
13691364
return fd;
1370-
}
13711365
}
13721366

13731367
return nullptr;
13741368
}
13751369

13761370
FuncDecl *ASTContext::getRecordReturnTypeOnDistributedInvocationEncoder(
13771371
NominalTypeDecl *nominal) const {
1378-
if (getImpl().RecordReturnTypeDistributedInvocationEncoderDecl) {
1379-
return getImpl().RecordReturnTypeDistributedInvocationEncoderDecl;
1380-
}
1381-
1382-
NominalTypeDecl *encoderProto =
1383-
nominal
1384-
? nominal
1385-
: getProtocol(KnownProtocolKind::DistributedTargetInvocationEncoder);
1386-
assert(encoderProto && "Missing DistributedTargetInvocationEncoder protocol");
1387-
for (auto result : encoderProto->lookupDirect(Id_recordReturnType)) {
1372+
for (auto result : nominal->lookupDirect(Id_recordReturnType)) {
13881373
auto *fd = dyn_cast<FuncDecl>(result);
13891374
if (!fd)
13901375
continue;
1391-
13921376
if (fd->getParameters()->size() != 1)
13931377
continue;
1378+
if (fd->hasAsync())
1379+
continue;
1380+
if (!fd->hasThrows())
1381+
continue;
1382+
// TODO(distributed): more checks
13941383

1395-
// TODO(distributed): more checks that the arg type matches (!!!)
1396-
1397-
if (fd->getResultInterfaceType()->isVoid() &&
1398-
fd->hasThrows() &&
1399-
!fd->hasAsync()) {
1400-
getImpl().RecordReturnTypeDistributedInvocationEncoderDecl = fd;
1384+
if (fd->getResultInterfaceType()->isVoid())
14011385
return fd;
1402-
}
14031386
}
14041387

14051388
return nullptr;
14061389
}
14071390

14081391
FuncDecl *ASTContext::getDoneRecordingOnDistributedInvocationEncoder(
14091392
NominalTypeDecl *nominal) const {
1410-
if (getImpl().DoneRecordingDistributedInvocationEncoderDecl) {
1411-
return getImpl().DoneRecordingDistributedInvocationEncoderDecl;
1412-
}
1413-
1414-
NominalTypeDecl *encoderProto =
1415-
nominal
1416-
? nominal
1417-
: getProtocol(KnownProtocolKind::DistributedTargetInvocationEncoder);
1418-
assert(encoderProto && "Missing DistributedTargetInvocationEncoder protocol");
1419-
for (auto result : encoderProto->lookupDirect(Id_doneRecording)) {
1393+
for (auto result : nominal->lookupDirect(Id_doneRecording)) {
14201394
auto *fd = dyn_cast<FuncDecl>(result);
14211395
if (!fd)
14221396
continue;
@@ -1426,10 +1400,8 @@ FuncDecl *ASTContext::getDoneRecordingOnDistributedInvocationEncoder(
14261400

14271401
if (fd->getResultInterfaceType()->isVoid() &&
14281402
fd->hasThrows() &&
1429-
!fd->hasAsync()) {
1430-
getImpl().DoneRecordingDistributedInvocationEncoderDecl = fd;
1403+
!fd->hasAsync())
14311404
return fd;
1432-
}
14331405
}
14341406

14351407
return nullptr;

lib/AST/Decl.cpp

Lines changed: 26 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7402,17 +7402,14 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
74027402
auto callId = isVoidReturn ? C.Id_remoteCallVoid : C.Id_remoteCall;
74037403

74047404
// Check the name
7405-
if (this->getBaseName() != callId)
7405+
if (getBaseName() != callId)
74067406
return false;
74077407

7408-
auto params = this->getParameters();
7408+
auto params = getParameters();
7409+
unsigned int expectedParamNum = isVoidReturn ? 4 : 5;
74097410

7410-
// Check the expected argument count
7411-
// - for value returning remoteCall:
7412-
if (!params || (!isVoidReturn && params->size() != 5))
7413-
return false;
7414-
// - for void returning remoteCallVoid:
7415-
if (!params || (isVoidReturn && params->size() != 4))
7411+
// Check the expected argument count:
7412+
if (!params || params->size() != expectedParamNum)
74167413
return false;
74177414

74187415
// Check API names of the arguments
@@ -7422,7 +7419,7 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
74227419
auto thrownTypeParam = params->get(3);
74237420
if (actorParam->getArgumentName() != C.Id_on ||
74247421
targetParam->getArgumentName() != C.Id_target ||
7425-
invocationParam->getArgumentName() != C.Id_invocationDecoder ||
7422+
invocationParam->getArgumentName() != C.Id_invocation ||
74267423
thrownTypeParam->getArgumentName() != C.Id_throwing)
74277424
return false;
74287425

@@ -7432,41 +7429,34 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
74327429
return false;
74337430
}
74347431

7435-
// FIXME(distributed): check the right types of the args and generics...
7436-
// FIXME(distributed): check access level actually is ok, i.e. not private etc
7437-
7438-
return true;
7439-
}
7432+
if (!isGeneric())
7433+
return false;
74407434

7441-
bool AbstractFunctionDecl::isDistributed() const {
7442-
return this->getAttrs().hasAttribute<DistributedActorAttr>();
7443-
}
7435+
auto genericParams = getGenericParams();
7436+
unsigned int expectedGenericParamNum = isVoidReturn ? 2 : 3;
74447437

7445-
AbstractFunctionDecl*
7446-
NominalTypeDecl::getDistributedActorSystemMakeInvocationEncoderFunction() const {
7447-
auto &C = this->getASTContext();
7448-
NominalTypeDecl *system = const_cast<NominalTypeDecl *>(this);
7449-
if (this->isDistributedActor()) {
7450-
auto var = this->getDistributedActorSystemProperty();
7451-
system = var->getInterfaceType()->getAnyNominal();
7438+
// We expect: Act, Err, Res?
7439+
if (genericParams->size() != expectedGenericParamNum) {
7440+
return false;
74527441
}
74537442

7454-
// FIXME(distributed): implement more properly...
7455-
for (auto value : system->lookupDirect(C.Id_makeInvocationEncoder)) {
7456-
auto func = dyn_cast<AbstractFunctionDecl>(value);
7457-
if (!func)
7458-
continue;
7443+
// FIXME(distributed): check the exact generic requirements
74597444

7460-
if (func->getParameters()->size() != 0)
7461-
continue;
7445+
// === check the return type
7446+
if (isVoidReturn) {
7447+
if (auto func = dyn_cast<FuncDecl>(this))
7448+
if (!func->getResultInterfaceType()->isVoid())
7449+
return false;
7450+
}
74627451

7463-
// TODO(distriuted): return type must conform to our expected protocol
7452+
// FIXME(distributed): check the right types of the args and generics...
7453+
// FIXME(distributed): check access level actually is ok, i.e. not private etc
74647454

7465-
return func;
7466-
}
7455+
return true;
7456+
}
74677457

7468-
// TODO(distributed): make a Request for it?
7469-
return nullptr;
7458+
bool AbstractFunctionDecl::isDistributed() const {
7459+
return getAttrs().hasAttribute<DistributedActorAttr>();
74707460
}
74717461

74727462
ConstructorDecl*

lib/SILGen/SILGenDistributed.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -724,8 +724,8 @@ void SILGenFunction::emitDistributedThunk(SILDeclRef thunk) {
724724

725725
// === `InvocationEncoder` types
726726
AbstractFunctionDecl *makeInvocationEncoderFnDecl =
727-
selfTyDecl->getDistributedActorSystemMakeInvocationEncoderFunction();
728-
assert(makeInvocationEncoderFnDecl && "no remoteCall func found!");
727+
ctx.getMakeInvocationEncoderOnDistributedActorSystem(selfTyDecl);
728+
assert(makeInvocationEncoderFnDecl && "no 'makeInvocationEncoder' func found!");
729729
auto makeInvocationEncoderFnRef = SILDeclRef(makeInvocationEncoderFnDecl);
730730

731731
ProtocolDecl *invocationEncoderProto =
@@ -882,7 +882,7 @@ void SILGenFunction::emitDistributedThunk(SILDeclRef thunk) {
882882

883883
// function_ref FakeActorSystem.makeInvocationEncoder()
884884
// %19 = function_ref @$s27FakeDistributedActorSystems0aC6SystemV21makeInvocationEncoderAA0aG0VyF : $@convention(method) (@guaranteed FakeActorSystem) -> FakeInvocation // user: %20
885-
auto makeInvocationEncoderFnSIL =
885+
SILFunction *makeInvocationEncoderFnSIL =
886886
builder.getOrCreateFunction(loc, makeInvocationEncoderFnRef, NotForDefinition);
887887
SILValue makeInvocationEncoderFn =
888888
B.createFunctionRefFor(loc, makeInvocationEncoderFnSIL);
@@ -1395,7 +1395,7 @@ void SILGenFunction::emitDistributedThunk(SILDeclRef thunk) {
13951395
}
13961396
assert(returnMetatypeValue);
13971397

1398-
// function_ref FakeActorSystem.remoteCall<A, B, C>(on:target:invocationDecoder:throwing:returning:)
1398+
// function_ref FakeActorSystem.remoteCall<A, B, C>(on:target:invocation:throwing:returning:)
13991399
// %49 = function_ref @$s27FakeDistributedActorSystems0aC6SystemV10remoteCall2on6target17invocationDecoder8throwing9returningq0_x_01_B006RemoteG6TargetVAA0A10InvocationVzq_mq0_mSgtYaKAJ0bC0RzSeR0_SER0_AA0C7AddressV2IDRtzr1_lF : $@convention(method) @async <τ_0_0, τ_0_1, τ_0_2 where τ_0_0 : DistributedActor, τ_0_2 : Decodable, τ_0_2 : Encodable, τ_0_0.ID == ActorAddress> (@guaranteed τ_0_0, @in_guaranteed RemoteCallTarget, @inout FakeInvocation, @thick τ_0_1.Type, Optional<@thick τ_0_2.Type>, @guaranteed FakeActorSystem) -> (@out τ_0_2, @error Error) // user: %50
14001400
auto remoteCallFnDecl =
14011401
ctx.getRemoteCallOnDistributedActorSystem(selfTyDecl, /*isVoid=*/resultType.isVoid());

0 commit comments

Comments
 (0)