Skip to content

[AST/Sema] Distributed: Refactor type and member queries #72107

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 0 additions & 73 deletions include/swift/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -726,68 +726,6 @@ class ASTContext final {
// Retrieve the declaration of Swift._stdlib_isOSVersionAtLeast.
FuncDecl *getIsOSVersionAtLeastDecl() const;

/// Retrieve the declaration of DistributedActorSystem.remoteCall(Void)(...).
///
/// \param actorOrSystem distributed actor or actor system to get the
/// remoteCall function for. Since the method we're looking for is an ad-hoc
/// requirement, a specific type MUST be passed here as it is not possible
/// to obtain the decl from just the `DistributedActorSystem` protocol type.
/// \param isVoidReturn true if the call will be returning `Void`.
AbstractFunctionDecl *getRemoteCallOnDistributedActorSystem(
NominalTypeDecl *actorOrSystem,
bool isVoidReturn) const;

/// Retrieve the declaration of DistributedActorSystem.make().
///
/// \param thunk the function from which we'll be invoking things on the obtained
/// actor system; This way we'll always get the right type, taking care of any
/// where clauses etc.
FuncDecl *getMakeInvocationEncoderOnDistributedActorSystem(
AbstractFunctionDecl *thunk) const;

// Retrieve the declaration of
// DistributedInvocationEncoder.recordGenericSubstitution(_:).
//
// \param nominal optionally provide a 'NominalTypeDecl' from which the
// function decl shall be extracted. This is useful to avoid witness calls
// through the protocol which is looked up when nominal is null.
FuncDecl *getRecordGenericSubstitutionOnDistributedInvocationEncoder(
NominalTypeDecl *nominal) const;

// Retrieve the declaration of DistributedTargetInvocationEncoder.recordArgument(_:).
//
// \param nominal optionally provide a 'NominalTypeDecl' from which the
// function decl shall be extracted. This is useful to avoid witness calls
// through the protocol which is looked up when nominal is null.
AbstractFunctionDecl *getRecordArgumentOnDistributedInvocationEncoder(
NominalTypeDecl *nominal) const;

// Retrieve the declaration of DistributedTargetInvocationEncoder.recordReturnType(_:).
AbstractFunctionDecl *getRecordReturnTypeOnDistributedInvocationEncoder(
NominalTypeDecl *nominal) const;

// Retrieve the declaration of DistributedTargetInvocationEncoder.recordErrorType(_:).
AbstractFunctionDecl *getRecordErrorTypeOnDistributedInvocationEncoder(
NominalTypeDecl *nominal) const;

// Retrieve the declaration of
// DistributedTargetInvocationDecoder.getDecodeNextArgumentOnDistributedInvocationDecoder(_:).
AbstractFunctionDecl *getDecodeNextArgumentOnDistributedInvocationDecoder(
NominalTypeDecl *nominal) const;

// Retrieve the declaration of
// getOnReturnOnDistributedTargetInvocationResultHandler.onReturn(_:).
AbstractFunctionDecl *getOnReturnOnDistributedTargetInvocationResultHandler(
NominalTypeDecl *nominal) const;

// Retrieve the declaration of DistributedInvocationEncoder.doneRecording().
//
// \param nominal optionally provide a 'NominalTypeDecl' from which the
// function decl shall be extracted. This is useful to avoid witness calls
// through the protocol which is looked up when nominal is null.
FuncDecl *getDoneRecordingOnDistributedInvocationEncoder(
NominalTypeDecl *nominal) const;

/// Look for the declaration with the given name within the
/// passed in module.
void lookupInModule(ModuleDecl *M, StringRef name,
Expand Down Expand Up @@ -1527,17 +1465,6 @@ class ASTContext final {
/// alternative specified via the -entry-point-function-name frontend flag.
std::string getEntryPointFunctionName() const;

Type getAssociatedTypeOfDistributedSystemOfActor(NominalTypeDecl *actor,
Identifier member);

/// Find the concrete invocation decoder associated with the given actor.
NominalTypeDecl *
getDistributedActorInvocationDecoder(NominalTypeDecl *);

/// Find `decodeNextArgument<T>(type: T.Type) -> T` method associated with
/// invocation decoder of the given distributed actor.
FuncDecl *getDistributedActorArgumentDecodingMethod(NominalTypeDecl *);

/// The special Builtin.TheTupleType, which parents tuple extensions and
/// conformances.
BuiltinTupleDecl *getBuiltinTupleDecl();
Expand Down
97 changes: 80 additions & 17 deletions include/swift/AST/DistributedDecl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ class DeclContext;
class FuncDecl;
class NominalTypeDecl;

Type getAssociatedTypeOfDistributedSystemOfActor(DeclContext *actorOrExtension,
Identifier member);

/// Find the concrete invocation decoder associated with the given actor.
NominalTypeDecl *getDistributedActorInvocationDecoder(NominalTypeDecl *);

/// Find `decodeNextArgument<T>(type: T.Type) -> T` method associated with
/// invocation decoder of the given distributed actor.
FuncDecl *getDistributedActorArgumentDecodingMethod(NominalTypeDecl *);

/// Determine the concrete type of 'ActorSystem' as seen from the member.
/// E.g. when in a protocol, and trying to determine what the actor system was
/// constrained to.
Expand All @@ -47,12 +57,6 @@ Type getDistributedActorSystemType(NominalTypeDecl *actor);
/// Determine the `ID` type for the given actor.
Type getDistributedActorIDType(NominalTypeDecl *actor);

/// Similar to `getDistributedSerializationRequirementType`, however, from the
/// perspective of a concrete function. This way we're able to get the
/// serialization requirement for specific members, also in protocols.
Type getSerializationRequirementTypesForMember(
ValueDecl *member, llvm::SmallPtrSet<ProtocolDecl *, 2> &serializationRequirements);

/// Get specific 'SerializationRequirement' as defined in 'nominal'
/// type, which must conform to the passed 'protocol' which is expected
/// to require the 'SerializationRequirement'.
Expand All @@ -66,6 +70,12 @@ AbstractFunctionDecl *
getAssociatedDistributedInvocationDecoderDecodeNextArgumentFunction(
ValueDecl *thunk);

Type getDistributedActorSerializationType(DeclContext *actorOrExtension);

/// Get the specific 'SerializationRequirement' type of a specific distributed
/// actor system.
Type getDistributedActorSystemSerializationType(NominalTypeDecl *system);

/// Get the specific 'InvocationEncoder' type of a specific distributed actor
/// system.
Type getDistributedActorSystemInvocationEncoderType(NominalTypeDecl *system);
Expand All @@ -81,17 +91,6 @@ Type getDistributedActorSystemResultHandlerType(NominalTypeDecl *system);
/// Get the 'ActorID' type of a specific distributed actor system.
Type getDistributedActorSystemActorIDType(NominalTypeDecl *system);

/// Get the specific protocols that the `SerializationRequirement` specifies,
/// and all parameters / return types of distributed targets must conform to.
///
/// E.g. if a system declares `typealias SerializationRequirement = Codable`
/// then this will return `{encodableProtocol, decodableProtocol}`.
///
/// Returns an empty set if the requirement was `Any`.
llvm::SmallPtrSet<ProtocolDecl *, 2>
getDistributedSerializationRequirementProtocols(
NominalTypeDecl *decl, ProtocolDecl* protocol);

/// Check if the `allRequirements` represent *exactly* the
/// `Encodable & Decodable` (also known as `Codable`) requirement.
///
Expand All @@ -115,6 +114,70 @@ getDistributedSerializationRequirements(
ProtocolDecl *protocol,
llvm::SmallPtrSetImpl<ProtocolDecl *> &requirementProtos);

/// Retrieve the declaration of DistributedActorSystem.remoteCall(Void)(...).
///
/// \param actorOrSystem distributed actor or actor system to get the
/// remoteCall function for. Since the method we're looking for is an ad-hoc
/// requirement, a specific type MUST be passed here as it is not possible
/// to obtain the decl from just the `DistributedActorSystem` protocol type.
/// \param isVoidReturn true if the call will be returning `Void`.
AbstractFunctionDecl *
getRemoteCallOnDistributedActorSystem(NominalTypeDecl *actorOrSystem,
bool isVoidReturn);

/// Retrieve the declaration of DistributedActorSystem.make().
///
/// \param thunk the function from which we'll be invoking things on the
/// obtained actor system; This way we'll always get the right type, taking care
/// of any where clauses etc.
FuncDecl *
getMakeInvocationEncoderOnDistributedActorSystem(AbstractFunctionDecl *thunk);

// Retrieve the declaration of
// DistributedInvocationEncoder.recordGenericSubstitution(_:).
//
// \param nominal optionally provide a 'NominalTypeDecl' from which the
// function decl shall be extracted. This is useful to avoid witness calls
// through the protocol which is looked up when nominal is null.
FuncDecl *getRecordGenericSubstitutionOnDistributedInvocationEncoder(
NominalTypeDecl *nominal);

// Retrieve the declaration of
// DistributedTargetInvocationEncoder.recordArgument(_:).
//
// \param nominal optionally provide a 'NominalTypeDecl' from which the
// function decl shall be extracted. This is useful to avoid witness calls
// through the protocol which is looked up when nominal is null.
AbstractFunctionDecl *
getRecordArgumentOnDistributedInvocationEncoder(NominalTypeDecl *nominal);

// Retrieve the declaration of
// DistributedTargetInvocationEncoder.recordReturnType(_:).
AbstractFunctionDecl *
getRecordReturnTypeOnDistributedInvocationEncoder(NominalTypeDecl *nominal);

// Retrieve the declaration of
// DistributedTargetInvocationEncoder.recordErrorType(_:).
AbstractFunctionDecl *
getRecordErrorTypeOnDistributedInvocationEncoder(NominalTypeDecl *nominal);

// Retrieve the declaration of
// DistributedTargetInvocationDecoder.getDecodeNextArgumentOnDistributedInvocationDecoder(_:).
AbstractFunctionDecl *
getDecodeNextArgumentOnDistributedInvocationDecoder(NominalTypeDecl *nominal);

// Retrieve the declaration of
// getOnReturnOnDistributedTargetInvocationResultHandler.onReturn(_:).
AbstractFunctionDecl *
getOnReturnOnDistributedTargetInvocationResultHandler(NominalTypeDecl *nominal);

// Retrieve the declaration of DistributedInvocationEncoder.doneRecording().
//
// \param nominal optionally provide a 'NominalTypeDecl' from which the
// function decl shall be extracted. This is useful to avoid witness calls
// through the protocol which is looked up when nominal is null.
FuncDecl *
getDoneRecordingOnDistributedInvocationEncoder(NominalTypeDecl *nominal);
}

#endif /* SWIFT_DECL_DISTRIBUTEDDECL_H */
113 changes: 0 additions & 113 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1545,119 +1545,6 @@ FuncDecl *ASTContext::getEqualIntDecl() const {
return getBinaryComparisonOperatorIntDecl(*this, "==", getImpl().EqualIntDecl);
}

FuncDecl *ASTContext::getMakeInvocationEncoderOnDistributedActorSystem(
AbstractFunctionDecl *thunk) const {
auto systemTy = getConcreteReplacementForProtocolActorSystemType(thunk);
assert(systemTy && "No specific ActorSystem type found!");

auto systemNominal = systemTy->getNominalOrBoundGenericNominal();
assert(systemNominal && "No system nominal type found!");

for (auto result : systemNominal->lookupDirect(Id_makeInvocationEncoder)) {
auto *func = dyn_cast<FuncDecl>(result);
if (func && func->isDistributedActorSystemMakeInvocationEncoder()) {
return func;
}
}

return nullptr;
}

FuncDecl *
ASTContext::getRecordGenericSubstitutionOnDistributedInvocationEncoder(
NominalTypeDecl *nominal) const {
if (!nominal)
return nullptr;

for (auto result : nominal->lookupDirect(Id_recordGenericSubstitution)) {
auto *func = dyn_cast<FuncDecl>(result);
if (func &&
func->isDistributedTargetInvocationEncoderRecordGenericSubstitution()) {
return func;
}
}

return nullptr;
}

AbstractFunctionDecl *ASTContext::getRecordArgumentOnDistributedInvocationEncoder(
NominalTypeDecl *nominal) const {
if (!nominal)
return nullptr;

return evaluateOrDefault(
nominal->getASTContext().evaluator,
GetDistributedTargetInvocationEncoderRecordArgumentFunctionRequest{nominal},
nullptr);
}

AbstractFunctionDecl *ASTContext::getRecordReturnTypeOnDistributedInvocationEncoder(
NominalTypeDecl *nominal) const {
if (!nominal)
return nullptr;

return evaluateOrDefault(
nominal->getASTContext().evaluator,
GetDistributedTargetInvocationEncoderRecordReturnTypeFunctionRequest{nominal},
nullptr);
}

AbstractFunctionDecl *ASTContext::getRecordErrorTypeOnDistributedInvocationEncoder(
NominalTypeDecl *nominal) const {
if (!nominal)
return nullptr;

return evaluateOrDefault(
nominal->getASTContext().evaluator,
GetDistributedTargetInvocationEncoderRecordErrorTypeFunctionRequest{nominal},
nullptr);
}

AbstractFunctionDecl *ASTContext::getDecodeNextArgumentOnDistributedInvocationDecoder(
NominalTypeDecl *nominal) const {
if (!nominal)
return nullptr;

return evaluateOrDefault(
nominal->getASTContext().evaluator,
GetDistributedTargetInvocationDecoderDecodeNextArgumentFunctionRequest{nominal},
nullptr);
}

AbstractFunctionDecl *ASTContext::getOnReturnOnDistributedTargetInvocationResultHandler(
NominalTypeDecl *nominal) const {
if (!nominal)
return nullptr;

return evaluateOrDefault(
nominal->getASTContext().evaluator,
GetDistributedTargetInvocationResultHandlerOnReturnFunctionRequest{nominal},
nullptr);
}

FuncDecl *ASTContext::getDoneRecordingOnDistributedInvocationEncoder(
NominalTypeDecl *nominal) const {

llvm::SmallVector<ValueDecl *, 2> results;
nominal->lookupQualified(nominal, DeclNameRef(Id_doneRecording),
SourceLoc(), NL_QualifiedDefault, results);
for (auto result : results) {
auto *fd = dyn_cast<FuncDecl>(result);
if (!fd)
continue;

if (fd->getParameters()->size() != 0)
continue;

if (fd->getResultInterfaceType()->isVoid() &&
fd->hasThrows() &&
!fd->hasAsync())
return fd;
}

return nullptr;
}

FuncDecl *ASTContext::getHashValueForDecl() const {
if (getImpl().HashValueForDecl)
return getImpl().HashValueForDecl;
Expand Down
Loading