Skip to content

Commit f0add8b

Browse files
committed
[AST/Sema] Distributed: Introduced unified way to retrieve serialization requirements for actors
1 parent 072a43e commit f0add8b

File tree

6 files changed

+116
-141
lines changed

6 files changed

+116
-141
lines changed

include/swift/AST/DistributedDecl.h

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class DeclContext;
3131
class FuncDecl;
3232
class NominalTypeDecl;
3333

34-
Type getAssociatedTypeOfDistributedSystemOfActor(NominalTypeDecl *actor,
34+
Type getAssociatedTypeOfDistributedSystemOfActor(DeclContext *actorOrExtension,
3535
Identifier member);
3636

3737
/// Find the concrete invocation decoder associated with the given actor.
@@ -57,12 +57,6 @@ Type getDistributedActorSystemType(NominalTypeDecl *actor);
5757
/// Determine the `ID` type for the given actor.
5858
Type getDistributedActorIDType(NominalTypeDecl *actor);
5959

60-
/// Similar to `getDistributedSerializationRequirementType`, however, from the
61-
/// perspective of a concrete function. This way we're able to get the
62-
/// serialization requirement for specific members, also in protocols.
63-
Type getSerializationRequirementTypesForMember(
64-
ValueDecl *member, llvm::SmallPtrSet<ProtocolDecl *, 2> &serializationRequirements);
65-
6660
/// Get specific 'SerializationRequirement' as defined in 'nominal'
6761
/// type, which must conform to the passed 'protocol' which is expected
6862
/// to require the 'SerializationRequirement'.
@@ -76,6 +70,12 @@ AbstractFunctionDecl *
7670
getAssociatedDistributedInvocationDecoderDecodeNextArgumentFunction(
7771
ValueDecl *thunk);
7872

73+
Type getDistributedActorSerializationType(DeclContext *actorOrExtension);
74+
75+
/// Get the specific 'SerializationRequirement' type of a specific distributed
76+
/// actor system.
77+
Type getDistributedActorSystemSerializationType(NominalTypeDecl *system);
78+
7979
/// Get the specific 'InvocationEncoder' type of a specific distributed actor
8080
/// system.
8181
Type getDistributedActorSystemInvocationEncoderType(NominalTypeDecl *system);
@@ -91,17 +91,6 @@ Type getDistributedActorSystemResultHandlerType(NominalTypeDecl *system);
9191
/// Get the 'ActorID' type of a specific distributed actor system.
9292
Type getDistributedActorSystemActorIDType(NominalTypeDecl *system);
9393

94-
/// Get the specific protocols that the `SerializationRequirement` specifies,
95-
/// and all parameters / return types of distributed targets must conform to.
96-
///
97-
/// E.g. if a system declares `typealias SerializationRequirement = Codable`
98-
/// then this will return `{encodableProtocol, decodableProtocol}`.
99-
///
100-
/// Returns an empty set if the requirement was `Any`.
101-
llvm::SmallPtrSet<ProtocolDecl *, 2>
102-
getDistributedSerializationRequirementProtocols(
103-
NominalTypeDecl *decl, ProtocolDecl* protocol);
104-
10594
/// Check if the `allRequirements` represent *exactly* the
10695
/// `Encodable & Decodable` (also known as `Codable`) requirement.
10796
///

lib/AST/DistributedDecl.cpp

Lines changed: 68 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -106,42 +106,6 @@ Type swift::getConcreteReplacementForProtocolActorSystemType(
106106
llvm_unreachable("Unable to fetch ActorSystem type!");
107107
}
108108

109-
Type swift::getSerializationRequirementTypesForMember(
110-
ValueDecl *member,
111-
llvm::SmallPtrSet<ProtocolDecl *, 2> &serializationRequirements) {
112-
auto &C = member->getASTContext();
113-
auto *DC = member->getDeclContext();
114-
auto DA = C.getDistributedActorDecl();
115-
116-
// === When declared inside an actor, we can get the type directly
117-
if (auto classDecl = DC->getSelfClassDecl()) {
118-
return getDistributedSerializationRequirementType(classDecl, C.getDistributedActorDecl());
119-
}
120-
121-
auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement)
122-
->getDeclaredInterfaceType();
123-
124-
if (DC->getSelfProtocolDecl()) {
125-
GenericSignature signature;
126-
if (auto *genericContext = member->getAsGenericContext()) {
127-
signature = genericContext->getGenericSignature();
128-
} else {
129-
signature = DC->getGenericSignatureOfContext();
130-
}
131-
132-
// Also store all `SerializationRequirement : SomeProtocol` requirements
133-
for (auto proto: signature->getRequiredProtocols(SerReqAssocType)) {
134-
serializationRequirements.insert(proto);
135-
}
136-
137-
// Note that this may be null, e.g. if we're a distributed func inside
138-
// a protocol that did not declare a specific actor system requirement.
139-
return signature->getConcreteType(SerReqAssocType);
140-
}
141-
142-
llvm_unreachable("Unable to fetch SerializationRequirement type!");
143-
}
144-
145109
Type swift::getDistributedActorSystemType(NominalTypeDecl *actor) {
146110
assert(!dyn_cast<ProtocolDecl>(actor) &&
147111
"Use getConcreteReplacementForProtocolActorSystemType instead to get"
@@ -179,6 +143,53 @@ static Type getTypeWitnessByName(NominalTypeDecl *type, ProtocolDecl *protocol,
179143
return conformance.getTypeWitnessByName(selfType, member);
180144
}
181145

146+
Type swift::getDistributedActorSerializationType(
147+
DeclContext *actorOrExtension) {
148+
auto &ctx = actorOrExtension->getASTContext();
149+
auto resultTy = getAssociatedTypeOfDistributedSystemOfActor(
150+
actorOrExtension,
151+
ctx.Id_SerializationRequirement);
152+
153+
// Protocols are allowed to either not provide a `SerializationRequirement`
154+
// at all or provide it in a conformance requirement.
155+
if ((!resultTy || resultTy->hasDependentMember()) &&
156+
actorOrExtension->getSelfProtocolDecl()) {
157+
auto sig = actorOrExtension->getGenericSignatureOfContext();
158+
159+
auto actorProtocol = ctx.getProtocol(KnownProtocolKind::DistributedActor);
160+
if (!actorProtocol)
161+
return Type();
162+
163+
auto serializationTy =
164+
actorProtocol->getAssociatedType(ctx.Id_SerializationRequirement)
165+
->getDeclaredInterfaceType();
166+
167+
auto protocols = sig->getRequiredProtocols(serializationTy);
168+
if (protocols.empty())
169+
return Type();
170+
171+
SmallVector<Type, 2> members;
172+
llvm::transform(protocols, std::back_inserter(members), [](const auto *P) {
173+
return P->getDeclaredInterfaceType();
174+
});
175+
176+
return ExistentialType::get(
177+
ProtocolCompositionType::get(ctx, members,
178+
/*inverses=*/{},
179+
/*HasExplicitAnyObject=*/false));
180+
}
181+
182+
return resultTy;
183+
}
184+
185+
Type swift::getDistributedActorSystemSerializationType(
186+
NominalTypeDecl *system) {
187+
assert(!system->isDistributedActor());
188+
auto &ctx = system->getASTContext();
189+
return getTypeWitnessByName(system, ctx.getDistributedActorSystemDecl(),
190+
ctx.Id_SerializationRequirement);
191+
}
192+
182193
Type swift::getDistributedActorSystemActorIDType(NominalTypeDecl *system) {
183194
assert(!system->isDistributedActor());
184195
auto &ctx = system->getASTContext();
@@ -248,19 +259,14 @@ swift::getAssociatedDistributedInvocationDecoderDecodeNextArgumentFunction(
248259
decoderTy->getAnyNominal());
249260
}
250261

251-
Type swift::getAssociatedTypeOfDistributedSystemOfActor(NominalTypeDecl *actor,
252-
Identifier member) {
253-
auto &ctx = actor->getASTContext();
262+
Type swift::getAssociatedTypeOfDistributedSystemOfActor(
263+
DeclContext *actorOrExtension, Identifier member) {
264+
auto &ctx = actorOrExtension->getASTContext();
254265

255266
auto actorProtocol = ctx.getProtocol(KnownProtocolKind::DistributedActor);
256267
if (!actorProtocol)
257268
return Type();
258269

259-
auto actorConformance = actor->getParentModule()->lookupConformance(
260-
actor->getDeclaredInterfaceType(), actorProtocol);
261-
if (!actorConformance || actorConformance.isInvalid())
262-
return Type();
263-
264270
AssociatedTypeDecl *actorSystemDecl =
265271
actorProtocol->getAssociatedType(ctx.Id_ActorSystem);
266272
if (!actorSystemDecl)
@@ -275,15 +281,27 @@ Type swift::getAssociatedTypeOfDistributedSystemOfActor(NominalTypeDecl *actor,
275281
if (!memberTypeDecl)
276282
return Type();
277283

278-
auto depMemTy = DependentMemberType::get(
279-
DependentMemberType::get(actorProtocol->getSelfInterfaceType(),
280-
actorSystemDecl),
281-
memberTypeDecl);
284+
Type memberTy = DependentMemberType::get(
285+
DependentMemberType::get(actorProtocol->getSelfInterfaceType(),
286+
actorSystemDecl),
287+
memberTypeDecl);
288+
289+
auto sig = actorOrExtension->getGenericSignatureOfContext();
290+
291+
auto *actorType = actorOrExtension->getSelfNominalTypeDecl();
292+
if (isa<ProtocolDecl>(actorType))
293+
return memberTy->getReducedType(sig);
294+
295+
auto actorConformance =
296+
actorOrExtension->getParentModule()->lookupConformance(
297+
actorType->getDeclaredInterfaceType(), actorProtocol);
298+
if (!actorConformance || actorConformance.isInvalid())
299+
return Type();
282300

283301
auto subs = SubstitutionMap::getProtocolSubstitutions(
284-
actorProtocol, actor->getDeclaredInterfaceType(), actorConformance);
302+
actorProtocol, actorType->getDeclaredInterfaceType(), actorConformance);
285303

286-
return Type(depMemTy).subst(subs);
304+
return memberTy.subst(subs)->getReducedType(sig);
287305
}
288306

289307
/******************************************************************************/

lib/Sema/CodeSynthesisDistributedActor.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -878,13 +878,9 @@ static bool canSynthesizeDistributedThunk(AbstractFunctionDecl *distributedTarge
878878
return true;
879879
}
880880

881-
SmallPtrSet<ProtocolDecl *, 2> requirementProtos;
882-
if (getSerializationRequirementTypesForMember(distributedTarget,
883-
requirementProtos)) {
884-
return true;
885-
}
886-
887-
return false;
881+
auto serializationTy =
882+
getDistributedActorSerializationType(distributedTarget->getDeclContext());
883+
return serializationTy && !serializationTy->hasDependentMember();
888884
}
889885

890886
/******************************************************************************/

lib/Sema/DerivedConformanceDistributedActor.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "TypeChecker.h"
2020
#include "swift/Strings.h"
2121
#include "TypeCheckDistributed.h"
22+
#include "swift/AST/ExistentialLayout.h"
2223
#include "swift/AST/NameLookupRequests.h"
2324
#include "swift/AST/ParameterList.h"
2425
#include "swift/AST/DistributedDecl.h"
@@ -219,7 +220,6 @@ static FuncDecl* createLocalFunc_doInvokeOnReturn(
219220
ParamDecl* handlerParam,
220221
ParamDecl* resultBufParam) {
221222
auto DC = parentFunc;
222-
auto DAS = C.getDistributedActorSystemDecl();
223223
auto doInvokeLocalFuncIdent = C.getIdentifier("doInvokeOnReturn");
224224

225225
// mock locations, we're a synthesized func and don't need real locations
@@ -243,7 +243,12 @@ static FuncDecl* createLocalFunc_doInvokeOnReturn(
243243
ParameterList::create(C, {resultTyParamDecl});
244244

245245
SmallVector<Requirement, 2> requirements;
246-
for (auto p : getDistributedSerializationRequirementProtocols(systemNominal, DAS)) {
246+
247+
auto serializationLayout =
248+
getDistributedActorSystemSerializationType(systemNominal)
249+
->getExistentialLayout();
250+
251+
for (auto p : serializationLayout.getProtocols()) {
247252
auto requirement =
248253
Requirement(RequirementKind::Conformance,
249254
resultGenericParamDecl->getDeclaredInterfaceType(),
@@ -285,7 +290,6 @@ deriveBodyDistributed_invokeHandlerOnReturn(AbstractFunctionDecl *afd,
285290
auto implicit = true;
286291
ASTContext &C = afd->getASTContext();
287292
auto DC = afd->getDeclContext();
288-
auto DAS = C.getDistributedActorSystemDecl();
289293

290294
// mock locations, we're a thunk and don't really need detailed locations
291295
const SourceLoc sloc = SourceLoc();
@@ -305,7 +309,7 @@ deriveBodyDistributed_invokeHandlerOnReturn(AbstractFunctionDecl *afd,
305309
auto metatypeParam = params->get(2);
306310

307311
auto serializationRequirementTypeTy =
308-
getDistributedSerializationRequirementType(nominal, DAS);
312+
getDistributedActorSystemSerializationType(nominal);
309313

310314
auto serializationRequirementMetaTypeTy =
311315
ExistentialMetatypeType::get(serializationRequirementTypeTy);
@@ -393,9 +397,6 @@ static FuncDecl *deriveDistributedActorSystem_invokeHandlerOnReturn(
393397
auto unsafeRawPointerType = C.getUnsafeRawPointerType();
394398
auto anyTypeType = ExistentialMetatypeType::get(C.TheAnyType); // Any.Type
395399

396-
// auto serializationRequirementType =
397-
// getDistributedSerializationRequirementType(system, DAS);
398-
399400
// params:
400401
// - handler: Self.ResultHandler
401402
// - resultBuffer:
@@ -583,7 +584,7 @@ deriveDistributedActorType_SerializationRequirement(
583584
return nullptr;
584585

585586
if (auto systemNominal = systemTy->getAnyNominal())
586-
return getDistributedSerializationRequirementType(systemNominal, DAS);
587+
return getDistributedActorSystemSerializationType(systemNominal);
587588

588589
return nullptr;
589590
}

0 commit comments

Comments
 (0)