Skip to content

[DNM][Distributed] Handle generic actors and Codable better #71467

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
3 changes: 3 additions & 0 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7458,6 +7458,9 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
getSILSynthesizeKind() == SILSynthesizeKind::DistributedActorFactory;
}

/// Determines whether this function is `DistributedActorSystem::remoteCall{Void}`.
bool isDistributedActorSystemRemoteCallRequirement(bool withResult) const;

/// Determines whether this function is a 'remoteCall' function,
/// which is used as ad-hoc protocol requirement by the
/// 'DistributedActorSystem' protocol.
Expand Down
6 changes: 5 additions & 1 deletion include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -1653,11 +1653,15 @@ class Solution {
/// Compute the set of substitutions for a generic signature opened at the
/// given locator.
///
/// \param decl The underlying declaration for which the substitutions are
/// computed.
///
/// \param sig The generic signature.
///
/// \param locator The locator that describes where the substitutions came
/// from.
SubstitutionMap computeSubstitutions(GenericSignature sig,
SubstitutionMap computeSubstitutions(NullablePtr<ValueDecl> decl,
GenericSignature sig,
ConstraintLocator *locator) const;

/// Resolves the contextual substitutions for a reference to a declaration
Expand Down
33 changes: 33 additions & 0 deletions lib/AST/DistributedDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,39 @@ bool swift::checkDistributedSerializationRequirementIsExactlyCodable(
/********************* Ad-hoc protocol requirement checks *********************/
/******************************************************************************/

bool AbstractFunctionDecl::isDistributedActorSystemRemoteCallRequirement(
bool withResult) const {
auto &ctx = getASTContext();
auto *DC = getDeclContext();

auto expectedName =
withResult ? DeclName(ctx, ctx.Id_remoteCall,
{ctx.Id_on, ctx.Id_target, ctx.Id_invocation,
ctx.Id_throwing, ctx.Id_returning})
: DeclName(ctx, ctx.Id_remoteCallVoid,
{ctx.Id_on, ctx.Id_target, ctx.Id_invocation,
ctx.Id_throwing});

if (getName() != expectedName)
return false;

if (!DC->isTypeContext() || !isGeneric())
return false;

auto declaredIn = DC->getSelfProtocolDecl();
if (!(declaredIn && declaredIn == ctx.getDistributedActorSystemDecl()))
return false;

auto genericParams = getGenericParams();
if (genericParams->size() != (withResult ? 3 : 2))
return false;

if (!hasThrows() || !hasAsync())
return false;

return true;
}

bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn) const {
auto &C = getASTContext();
auto module = getParentModule();
Expand Down
45 changes: 45 additions & 0 deletions lib/IRGen/GenProto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,14 @@ class EmitPolymorphicParameters : public PolymorphicConvention {
// Did the convention decide that the parameter at the given index
// was a class-pointer source?
bool isClassPointerSource(unsigned paramIndex);

// If we are building a protocol witness thunk for a
// `DistributedActorSystem.remoteCall` requirement we
// need to supply witness tables associated with `Res`
// generic parameter which are not expressible on the
// requirement because they come from `SerializationRequirement`
// associated type.
void injectAdHocDistributedRemoteCallRequirements();
};

} // end anonymous namespace
Expand Down Expand Up @@ -722,6 +730,40 @@ bool EmitPolymorphicParameters::isClassPointerSource(unsigned paramIndex) {
return false;
}

void EmitPolymorphicParameters::injectAdHocDistributedRemoteCallRequirements() {
// FIXME: We need a better way to recognize that function is
// a thunk for witness of `remoteCall` requirement.
if (!Fn.hasLocation())
return;

auto loc = Fn.getLocation();

auto *funcDecl = dyn_cast_or_null<FuncDecl>(loc.getAsDeclContext());
if (!(funcDecl && funcDecl->isDistributedActorSystemRemoteCall(
/*isVoidReturn=*/false)))
return;

auto sig = funcDecl->getGenericSignature();
auto resultInterfaceTy = funcDecl->getResultInterfaceType();
auto resultArchetypeTy =
getTypeInContext(resultInterfaceTy->getCanonicalType());
llvm::Value *resultMetadata = IGF.emitTypeMetadataRef(resultArchetypeTy);

auto resultRequirements = sig->getLocalRequirements(resultInterfaceTy);
for (auto *proto : resultRequirements.protos) {
// Lookup the witness table for this protocol dynamically via
// swift_conformsToProtocol(<<archetype>>, <<protocol>>)
auto *witnessTable = IGF.Builder.CreateCall(
IGM.getConformsToProtocolFunctionPointer(),
{resultMetadata, IGM.getAddrOfProtocolDescriptor(proto)});

IGF.setUnscopedLocalTypeData(
resultArchetypeTy,
LocalTypeDataKind::forAbstractProtocolWitnessTable(proto),
witnessTable);
}
}

namespace {

/// A class for binding type parameters of a generic function.
Expand Down Expand Up @@ -2574,6 +2616,9 @@ void EmitPolymorphicParameters::emit(EntryPointArgumentEmission &emission,

// Bind all the fulfillments we can from the formal parameters.
bindParameterSources(getParameter);

// Inject ad-hoc `remoteCall` requirements if any.
injectAdHocDistributedRemoteCallRequirements();
}

MetadataResponse
Expand Down
29 changes: 24 additions & 5 deletions lib/Sema/CSApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ static bool isOpenedAnyObject(Type type) {
}

SubstitutionMap
Solution::computeSubstitutions(GenericSignature sig,
Solution::computeSubstitutions(NullablePtr<ValueDecl> decl,
GenericSignature sig,
ConstraintLocator *locator) const {
if (sig.isNull())
return SubstitutionMap();
Expand All @@ -103,6 +104,7 @@ Solution::computeSubstitutions(GenericSignature sig,
subs[opened.first] = type;
}

auto *DC = constraintSystem->DC;
auto lookupConformanceFn =
[&](CanType original, Type replacement,
ProtocolDecl *protoType) -> ProtocolConformanceRef {
Expand All @@ -113,8 +115,24 @@ Solution::computeSubstitutions(GenericSignature sig,
}

// FIXME: Retrieve the conformance from the solution itself.
return getConstraintSystem().DC->getParentModule()->checkConformance(
replacement, protoType);
auto conformance =
DC->getParentModule()->checkConformance(replacement, protoType);

if (conformance.isInvalid()) {
if (auto *funcDecl = dyn_cast<FuncDecl>(decl.getPtrOrNull())) {
if (funcDecl->isDistributedActorSystemRemoteCall(
/*isVoidResult=*/false)) {
// `Res` conformances would be looked by at runtime but are
// guaranteed to be there by Sema because all distributed
// methods and accessors are checked to conform to
// `SerializationRequirement` of `DistributedActorSystem`.
if (original->isEqual(funcDecl->getResultInterfaceType()))
return ProtocolConformanceRef(protoType);
}
}
}

return conformance;
};

return SubstitutionMap::get(sig,
Expand Down Expand Up @@ -147,7 +165,7 @@ Solution::resolveConcreteDeclRef(ValueDecl *decl,

// Get the generic signature of the decl and compute the substitutions.
auto sig = decl->getInnermostDeclContext()->getGenericSignatureOfContext();
auto subst = computeSubstitutions(sig, locator);
auto subst = computeSubstitutions(decl, sig, locator);

maybeInstantiateCXXMethodDefinition(decl);

Expand Down Expand Up @@ -7121,7 +7139,8 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType,
auto opaqueLocator = solution.getConstraintSystem().getOpenOpaqueLocator(
locator, opaqueDecl);
SubstitutionMap substitutions = solution.computeSubstitutions(
opaqueDecl->getOpaqueInterfaceGenericSignature(), opaqueLocator);
opaqueDecl, opaqueDecl->getOpaqueInterfaceGenericSignature(),
opaqueLocator);

// If we don't have substitutions, this is an opaque archetype from
// another declaration being manipulated, and not an erasure of a
Expand Down
44 changes: 34 additions & 10 deletions lib/Sema/CodeSynthesisDistributedActor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,25 @@

#include "TypeCheckDistributed.h"

#include "TypeChecker.h"
#include "DerivedConformances.h"
#include "TypeCheckType.h"
#include "TypeChecker.h"
#include "swift/AST/ASTMangler.h"
#include "swift/AST/ASTPrinter.h"
#include "swift/AST/Availability.h"
#include "swift/AST/DistributedDecl.h"
#include "swift/AST/ExistentialLayout.h"
#include "swift/AST/Expr.h"
#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/Initializer.h"
#include "swift/AST/NameLookupRequests.h"
#include "swift/AST/ParameterList.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/AST/NameLookupRequests.h"
#include "swift/AST/ASTMangler.h"
#include "swift/AST/DistributedDecl.h"
#include "swift/Basic/Defer.h"
#include "swift/ClangImporter/ClangModule.h"
#include "swift/Sema/ConstraintSystem.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "DerivedConformances.h"

using namespace swift;

Expand Down Expand Up @@ -750,11 +751,8 @@ static FuncDecl *createDistributedThunkFunction(FuncDecl *func) {
/*********************** CODABLE CONFORMANCE **********************************/
/******************************************************************************/

static NormalProtocolConformance*
addDistributedActorCodableConformance(
ClassDecl *actor, ProtocolDecl *proto) {
assert(proto->isSpecificProtocol(swift::KnownProtocolKind::Decodable) ||
proto->isSpecificProtocol(swift::KnownProtocolKind::Encodable));
static NormalProtocolConformance *
addDistributedActorCodableConformance(ClassDecl *actor, ProtocolDecl *proto) {
auto &C = actor->getASTContext();
auto module = actor->getParentModule();

Expand All @@ -763,6 +761,32 @@ addDistributedActorCodableConformance(
return nullptr;
}

if (actor->isGeneric()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want isGeneric() or isGenericContext()?

auto idTy =
C.getAssociatedTypeOfDistributedSystemOfActor(actor, C.Id_ActorID);
if (idTy->hasError()) {
return nullptr;
}
auto encodableConf = module->lookupConformance(
idTy, C.getProtocol(swift::KnownProtocolKind::Encodable),
/*allowMissing=*/true);
auto decodableConf = module->lookupConformance(
idTy, C.getProtocol(swift::KnownProtocolKind::Decodable),
/*allowMissing=*/true);

// the system's ID is not codable, thus the actor isn't as well -- don't add
// the conformance.
if (encodableConf.isInvalid()) {
return nullptr;
}
if (decodableConf.isInvalid()) {
return nullptr;
}
}

assert(proto->isSpecificProtocol(swift::KnownProtocolKind::Decodable) ||
proto->isSpecificProtocol(swift::KnownProtocolKind::Encodable));

// === Does the actor explicitly conform to the protocol already?
auto explicitConformance =
module->lookupConformance(actor->getInterfaceType(), proto);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will always return false because getInterfaceType() returns a metatype given a TypeDecl

Expand Down
13 changes: 13 additions & 0 deletions lib/Sema/DerivedConformanceDistributedActor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,19 @@ deriveDistributedActorType_ActorSystem(
assert(derived.Nominal->isDistributedActor());
auto &C = derived.Context;


// If the actor is generic over ActorSystem,
// we don't need to synthesize the typealias.
if (derived.Nominal->getGenericSignature()) {
auto genSig = derived.Nominal->getGenericSignature();
for (auto param : genSig.getGenericParams()) {
if (param->getName() == C.Id_ActorSystem) {
return nullptr;
}
}
}


// Look for a type DefaultDistributedActorSystem within the parent context.
auto defaultDistributedActorSystemLookup = TypeChecker::lookupUnqualified(
derived.getConformanceContext()->getModuleScopeContext(),
Expand Down
57 changes: 48 additions & 9 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1195,13 +1195,52 @@ swift::matchWitness(WitnessChecker::RequirementEnvironmentCache &reqEnvCache,
}
bool requiresNonSendable = false;
if (!solution || solution->Fixes.size()) {
/// If the *only* problems are that `@Sendable` attributes are missing,
/// allow the match in some circumstances.
requiresNonSendable = solution
&& llvm::all_of(solution->Fixes, [](constraints::ConstraintFix *fix) {
return fix->getKind() == constraints::FixKind::AddSendableAttribute;
});
if (!requiresNonSendable)
auto isMatchedAllowed = [&](const constraints::Solution &solution) {
/// If the *only* problems are that `@Sendable` attributes are missing,
/// allow the match in some circumstances.
if (llvm::all_of(solution.Fixes, [](constraints::ConstraintFix *fix) {
return fix->getKind() ==
constraints::FixKind::AddSendableAttribute;
}))
return true;

auto *funcDecl = dyn_cast<AbstractFunctionDecl>(req);
// Conformance requirement between on `Res` and `SerializationRequirement`
// of `DistributedActorSystem.remoteCall` are not expressible at the moment
// but they are verified by Sema so it's okay to omit them here and lookup
// dynamically during IRGen.
if (funcDecl && funcDecl->isDistributedActorSystemRemoteCallRequirement(
/*withResult=*/true)) {
if (llvm::all_of(solution.Fixes, [&witness](constraints::ConstraintFix
*fix) {
auto conformance = dyn_cast<MissingConformance>(fix);
if (!conformance)
return false;

auto *locator = fix->getLocator();
auto requirement = locator->getLastElementAs<
LocatorPathElt::TypeParameterRequirement>();
if (!requirement)
return false;

auto signature =
locator->findLast<LocatorPathElt::OpenedGeneric>()
->getSignature();

auto subject =
signature.getRequirements()[requirement->getIndex()]
.getFirstType();
// `Res` is the result type so we can check against that.
return subject->isEqual(
cast<FuncDecl>(witness)->getResultInterfaceType());
}))
return true;
}
// In all other cases - disallow the match.
return false;
};

if (!solution || !isMatchedAllowed(*solution))
return RequirementMatch(witness, MatchKind::TypeConflict,
witnessType);
}
Expand Down Expand Up @@ -1231,8 +1270,8 @@ swift::matchWitness(WitnessChecker::RequirementEnvironmentCache &reqEnvCache,
auto witnessSig =
witness->getInnermostDeclContext()->getGenericSignatureOfContext();
result.WitnessSubstitutions =
solution->computeSubstitutions(witnessSig, witnessLocator);
solution->computeSubstitutions(witness, witnessSig, witnessLocator);

return result;
};

Expand Down
Loading