Skip to content

Eliminate required type erasure from distributed actors #40033

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -4417,6 +4417,8 @@ ERROR(actor_protocol_illegal_inheritance,none,
ERROR(distributed_actor_protocol_illegal_inheritance,none,
"non-distributed actor type %0 cannot conform to the 'DistributedActor' protocol",
(DeclName))
ERROR(broken_distributed_actor_requirement,none,
"DistributedActor protocol is broken: unexpected requirement", ())

ERROR(unowned_executor_outside_actor,none,
"'unownedExecutor' can only be implemented within the main "
Expand Down
2 changes: 2 additions & 0 deletions include/swift/AST/KnownIdentifiers.def
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ IDENTIFIER(decode)
IDENTIFIER(decodeIfPresent)
IDENTIFIER(Decoder)
IDENTIFIER(decoder)
IDENTIFIER(DefaultActorTransport)
IDENTIFIER_(Differentiation)
IDENTIFIER_WITH_NAME(PatternMatchVar, "$match")
IDENTIFIER(dynamicallyCall)
Expand Down Expand Up @@ -141,6 +142,7 @@ IDENTIFIER_WITH_NAME(SwiftObject, "_TtCs12_SwiftObject")
IDENTIFIER(SwiftNativeNSObject)
IDENTIFIER(to)
IDENTIFIER(toRaw)
IDENTIFIER(Transport)
IDENTIFIER(Type)
IDENTIFIER(type)
IDENTIFIER(typeMismatch)
Expand Down
58 changes: 37 additions & 21 deletions lib/SILGen/SILGenDistributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,45 @@ static VarDecl* lookupProperty(NominalTypeDecl *ty, DeclName name) {
return dyn_cast<VarDecl>(refs.front());
}

/// Emit a reference to a specific stored property of the actor.
static SILValue emitActorPropertyReference(
SILGenFunction &SGF, SILLocation loc, SILValue actorSelf,
VarDecl *property) {
Type formalType = SGF.F.mapTypeIntoContext(property->getInterfaceType());
SILType loweredType = SGF.getLoweredType(formalType).getAddressType();
#if false
if (!loweredType.isAddress()) {
loweredType = SILType::getPrimitiveAddressType(
formalType->getCanonicalType());
}
#endif
return SGF.B.createRefElementAddr(loc, actorSelf, property, loweredType);
}

/// Perform an initializing store to the given property using the value
/// \param actorSelf the value representing `self` for the actor instance.
/// \param prop the property to be initialized.
/// \param value the value to use when initializing the property.
static void initializeProperty(SILGenFunction &SGF, SILLocation loc,
SILValue actorSelf,
VarDecl* prop, SILValue value) {
CanType propType = SGF.F.mapTypeIntoContext(prop->getInterfaceType())
->getCanonicalType();
SILType loweredPropType = SGF.getLoweredType(propType);
auto fieldAddr = SGF.B.createRefElementAddr(loc, actorSelf, prop, loweredPropType);
Type formalType = SGF.F.mapTypeIntoContext(prop->getInterfaceType());
SILType loweredType = SGF.getLoweredType(formalType);

auto fieldAddr = emitActorPropertyReference(SGF, loc, actorSelf, prop);

if (fieldAddr->getType().isAddress())
if (loweredType.isAddressOnly(SGF.F)) {
SGF.B.createCopyAddr(loc, value, fieldAddr, IsNotTake, IsInitialization);
else
SGF.B.emitStoreValueOperation(loc, value, fieldAddr, StoreOwnershipQualifier::Init);
} else {
if (value->getType().isAddress()) {
value = SGF.B.createTrivialLoadOr(
loc, value, LoadOwnershipQualifier::Copy);
}

value = SGF.B.emitCopyValueOperation(loc, value);
SGF.B.emitStoreValueOperation(
loc, value, fieldAddr, StoreOwnershipQualifier::Init);
}
}

/******************************************************************************/
Expand Down Expand Up @@ -229,7 +252,6 @@ void SILGenFunction::emitDistributedActorFactory(FuncDecl *fd) {
// ==== Prepare argument references
// --- Parameter: identity
SILArgument *identityArg = F.getArgument(0);
assert(identityArg->getType().getASTType()->isEqual(C.getAnyActorIdentityType()));

// --- Parameter: transport
SILArgument *transportArg = F.getArgument(1);
Expand Down Expand Up @@ -355,26 +377,20 @@ void SILGenFunction::emitResignIdentityCall(SILLocation loc,
FormalEvaluationScope scope(*this);

// ==== locate: self.id
auto *idVarDeclRef = lookupProperty(actorDecl, ctx.Id_id);
CanType idType = F.mapTypeIntoContext(
idVarDeclRef->getInterfaceType())->getCanonicalType();
auto idRef = B.createRefElementAddr(loc, actorSelf, idVarDeclRef,
getLoweredType(idType));
auto idRef = emitActorPropertyReference(
*this, loc, actorSelf.getValue(), lookupProperty(actorDecl, ctx.Id_id));

// ==== locate: self.actorTransport
auto transportVarDeclRef = lookupProperty(actorDecl, ctx.Id_actorTransport);
CanType transportType = F.mapTypeIntoContext(
transportVarDeclRef->getInterfaceType())->getCanonicalType();
auto transportRef =
B.createRefElementAddr(loc, actorSelf, transportVarDeclRef,
getLoweredType(transportType));
auto transportRef = emitActorPropertyReference(
*this, loc, actorSelf.getValue(),
lookupProperty(actorDecl, ctx.Id_actorTransport));

// Perform the call.
emitActorTransportWitnessCall(
B, loc, ctx.Id_resignIdentity,
transportRef.getValue(),
transportRef,
SILType(),
{ idRef.getValue() });
{ idRef });
}

void
Expand Down
13 changes: 4 additions & 9 deletions lib/SILOptimizer/Utils/DistributedActor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ void emitActorTransportWitnessCall(
if (methodSILFnTy->getSelfParameter().isFormalIndirect() &&
!transport->getType().isAddress()) {
auto buf = B.createAllocStack(loc, transport->getType(), None);
transport = B.emitCopyValueOperation(loc, transport);
B.emitStoreValueOperation(
loc, transport, buf, StoreOwnershipQualifier::Init);
temporaryTransportBuffer = SILValue(buf);
Expand Down Expand Up @@ -169,18 +170,12 @@ void emitActorTransportWitnessCall(
// If we had to create a buffer to pass the transport
if (temporaryTransportBuffer) {
emitCleanup([&](SILBuilder & builder) {
auto value = builder.emitLoadValueOperation(
loc, *temporaryTransportBuffer, LoadOwnershipQualifier::Take);
builder.emitDestroyValueOperation(loc, value);
builder.createDeallocStack(loc, *temporaryTransportBuffer);
});
}

// If we opened an existential, then destroy the existential.
#if false
if (openedExistential) {
emitCleanup([&](SILBuilder & builder) {
builder.emitDestroyAddr(loc, transport);
});
}
#endif
}

void emitActorReadyCall(SILBuilder &B, SILLocation loc, SILValue actor,
Expand Down
61 changes: 53 additions & 8 deletions lib/Sema/DerivedConformanceDistributedActor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ static FuncDecl *deriveDistributedActor_resolve(DerivedConformance &derived) {
return param;
};

Type addressType = C.getAnyActorIdentityDecl()->getDeclaredInterfaceType();
Type transportType = getDistributedActorTransportType(decl);
auto addressType = getDistributedActorIdentityType(decl);
auto transportType = getDistributedActorTransportType(decl);

// (_ identity: AnyActorIdentity, using transport: ActorTransport)
// (_ identity: Identity, using transport: ActorTransport)
auto *params = ParameterList::create(
C,
/*LParenLoc=*/SourceLoc(),
Expand All @@ -77,7 +77,7 @@ static FuncDecl *deriveDistributedActor_resolve(DerivedConformance &derived) {
// Func name: resolve(_:using:)
DeclName name(C, C.Id_resolve, params);

// Expected type: (Self) -> (AnyActorIdentity, ActorTransport) throws -> (Self)
// Expected type: (Self) -> (Identity, ActorTransport) throws -> (Self)
auto *factoryDecl =
FuncDecl::createImplicit(C, StaticSpellingKind::KeywordStatic,
name, SourceLoc(),
Expand Down Expand Up @@ -105,9 +105,9 @@ static ValueDecl *deriveDistributedActor_id(DerivedConformance &derived) {

// ```
// nonisolated
// let id: AnyActorIdentity
// let id: Identity
// ```
auto propertyType = C.getAnyActorIdentityDecl()->getDeclaredInterfaceType();
auto propertyType = getDistributedActorIdentityType(derived.Nominal);

VarDecl *propDecl;
PatternBindingDecl *pbDecl;
Expand All @@ -133,10 +133,11 @@ static ValueDecl *deriveDistributedActor_actorTransport(

// ```
// nonisolated
// let actorTransport: ActorTransport
// let actorTransport: Transport
// ```
// (no need for @actorIndependent because it is an immutable let)
Type propertyType = getDistributedActorTransportType(derived.Nominal);
auto propertyType = getDistributedActorTransportType(derived.Nominal);

VarDecl *propDecl;
PatternBindingDecl *pbDecl;
std::tie(propDecl, pbDecl) = derived.declareDerivedProperty(
Expand All @@ -154,6 +155,36 @@ static ValueDecl *deriveDistributedActor_actorTransport(
return propDecl;
}

static Type deriveDistributedActor_Transport(
DerivedConformance &derived) {
assert(derived.Nominal->isDistributedActor());
auto &C = derived.Context;

// Look for a type DefaultActorTransport within the parent context.
auto defaultTransportLookup = TypeChecker::lookupUnqualified(
derived.getConformanceContext()->getModuleScopeContext(),
DeclNameRef(C.Id_DefaultActorTransport),
derived.ConformanceDecl->getLoc());
TypeDecl *defaultTransportTypeDecl = nullptr;
for (const auto &found : defaultTransportLookup) {
if (auto foundType = dyn_cast_or_null<TypeDecl>(found.getValueDecl())) {
if (defaultTransportTypeDecl) {
// Note: ambiguity, for now just fail.
return nullptr;
}

defaultTransportTypeDecl = foundType;
continue;
}
}

// There is no default, so fail to synthesize.
if (!defaultTransportTypeDecl)
return nullptr;

// Return the default transport type.
return defaultTransportTypeDecl->getDeclaredInterfaceType();
}
// ==== ------------------------------------------------------------------------

ValueDecl *DerivedConformance::deriveDistributedActor(ValueDecl *requirement) {
Expand All @@ -174,3 +205,17 @@ ValueDecl *DerivedConformance::deriveDistributedActor(ValueDecl *requirement) {

return nullptr;
}

std::pair<Type, TypeDecl *> DerivedConformance::deriveDistributedActor(
AssociatedTypeDecl *assocType) {
if (!canDeriveDistributedActor(Nominal, cast<DeclContext>(ConformanceDecl)))
return std::make_pair(Type(), nullptr);

if (assocType->getName() == Context.Id_Transport) {
return std::make_pair(deriveDistributedActor_Transport(*this), nullptr);
}

Context.Diags.diagnose(assocType->getLoc(),
diag::broken_distributed_actor_requirement);
return std::make_pair(Type(), nullptr);
}
6 changes: 6 additions & 0 deletions lib/Sema/DerivedConformances.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,12 @@ class DerivedConformance {
/// \returns the derived member, which will also be added to the type.
ValueDecl *deriveDistributedActor(ValueDecl *requirement);

/// Derive a DistributedActor associated type for a distributed actor.
///
/// \returns the derived type member, which will also be added to the type.
std::pair<Type, TypeDecl *> deriveDistributedActor(
AssociatedTypeDecl *assocType);

/// Determine if \c Actor can be derived for the given type.
static bool canDeriveActor(DeclContext *DC, NominalTypeDecl *NTD);

Expand Down
56 changes: 44 additions & 12 deletions lib/Sema/TypeCheckDistributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,20 +184,15 @@ void swift::checkDistributedActorConstructor(const ClassDecl *decl, ConstructorD
if (!ctor->isDesignatedInit())
return;

// === Designated initializers must accept exactly one ActorTransport
auto &C = ctor->getASTContext();
auto module = ctor->getParentModule();

// === Designated initializers must accept exactly one actor transport that
// matches the actor transport type of the actor.
SmallVector<ParamDecl*, 2> transportParams;
int transportParamsCount = 0;
auto protocolDecl = C.getProtocol(KnownProtocolKind::ActorTransport);
auto protocolTy = protocolDecl->getDeclaredInterfaceType();

Type transportTy = ctor->mapTypeIntoContext(
getDistributedActorTransportType(const_cast<ClassDecl *>(decl)));
for (auto param : *ctor->getParameters()) {
auto paramTy = ctor->mapTypeIntoContext(param->getInterfaceType());
auto conformance = TypeChecker::conformsToProtocol(paramTy, protocolDecl, module);

if (paramTy->isEqual(protocolTy) || !conformance.isInvalid()) {
if (paramTy->isEqual(transportTy)) {
transportParamsCount += 1;
transportParams.push_back(param);
}
Expand Down Expand Up @@ -252,9 +247,46 @@ Type swift::getDistributedActorTransportType(NominalTypeDecl *actor) {
assert(actor->isDistributedActor());
auto &ctx = actor->getASTContext();

auto protocol = ctx.getProtocol(KnownProtocolKind::ActorTransport);
auto protocol = ctx.getProtocol(KnownProtocolKind::DistributedActor);
if (!protocol)
return ErrorType::get(ctx);

return protocol->getDeclaredInterfaceType();
// Dig out the actor transport type.
auto module = actor->getParentModule();
Type selfType = actor->getSelfInterfaceType();
auto conformance = module->lookupConformance(selfType, protocol);
return conformance.getTypeWitnessByName(selfType, ctx.Id_Transport);
}

Type swift::getDistributedActorIdentityType(NominalTypeDecl *actor) {
assert(actor->isDistributedActor());
auto &ctx = actor->getASTContext();

auto actorProtocol = ctx.getProtocol(KnownProtocolKind::DistributedActor);
if (!actorProtocol)
return ErrorType::get(ctx);

AssociatedTypeDecl *transportDecl =
actorProtocol->getAssociatedType(ctx.Id_Transport);
if (!transportDecl)
return ErrorType::get(ctx);

auto transportProtocol = ctx.getProtocol(KnownProtocolKind::ActorTransport);
if (!transportProtocol)
return ErrorType::get(ctx);

AssociatedTypeDecl *identityDecl =
transportProtocol->getAssociatedType(ctx.getIdentifier("Identity"));
if (!identityDecl)
return ErrorType::get(ctx);

auto module = actor->getParentModule();
Type selfType = actor->getSelfInterfaceType();
auto conformance = module->lookupConformance(selfType, actorProtocol);
Type dependentType = actorProtocol->getSelfInterfaceType();
dependentType = DependentMemberType::get(dependentType, transportDecl);
dependentType = DependentMemberType::get(dependentType, identityDecl);
return dependentType.subst(
SubstitutionMap::getProtocolSubstitutions(
actorProtocol, selfType, conformance));
}
3 changes: 3 additions & 0 deletions lib/Sema/TypeCheckDistributed.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ bool checkDistributedFunction(FuncDecl *decl, bool diagnose);
/// Determine the distributed actor transport type for the given actor.
Type getDistributedActorTransportType(NominalTypeDecl *actor);

/// Determine the distributed actor identity type for the given actor.
Type getDistributedActorIdentityType(NominalTypeDecl *actor);

/// Diagnose a distributed func declaration in a not-distributed actor protocol.
void diagnoseDistributedFunctionInNonDistributedActorProtocol(
const ProtocolDecl *proto, InFlightDiagnostic &diag);
Expand Down
2 changes: 2 additions & 0 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6601,6 +6601,8 @@ TypeChecker::deriveTypeWitness(DeclContext *DC,
return std::make_pair(derived.deriveCaseIterable(AssocType), nullptr);
case KnownProtocolKind::Differentiable:
return derived.deriveDifferentiable(AssocType);
case KnownProtocolKind::DistributedActor:
return derived.deriveDistributedActor(AssocType);
default:
return std::make_pair(nullptr, nullptr);
}
Expand Down
Loading