Skip to content

[Distributed] Re-implement ad-hoc requirements into dynamic witness table lookup for SerializationRequirement conformance #71435

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
29a086d
[ConstraintSystem] NFC: Pass an underlying declaration to `Solution::…
xedin Feb 5, 2024
e59727a
[stdlib] Distributed: Promote `remoteCall{Void}` into protocol requir…
xedin Feb 6, 2024
f7e01fb
[TypeChecker] Allow missing requirements on `Res` generic parameter o…
xedin Feb 7, 2024
6afa732
[CSApply] Produce an abstract conformance for `Res` parameter of `Dis…
xedin Feb 7, 2024
c05497f
[IRGen] Implement dynamic witness table lookup for `Res` of `Distribu…
xedin Feb 7, 2024
171fb0a
[CSSimplify] Synthesize conformances for ad-hoc distributed witness r…
xedin Feb 7, 2024
0db2231
[ConstraintSystem] Record conformances synthesized for ad-hoc distrib…
xedin Feb 7, 2024
4f32111
[Sema/IRGen] Extend ad-hoc requirement handling to `DistributedTarget…
xedin Feb 7, 2024
4d45701
[Sema/IRGen] Extend ad-hoc requirement handling to `DistributedTarget…
xedin Feb 7, 2024
6c7000a
[Sema/IRGen] Extend ad-hoc requirement handling to `DistributedTarget…
xedin Feb 8, 2024
961aa30
[stdlib] Distributed: Remove `invokeOnReturn` requirement and its syn…
xedin Feb 8, 2024
a6a2a74
[SILOptimizer] Distributed: Suppress signature specialization for wit…
xedin Feb 8, 2024
cdc9a01
[Tests] NFC: Adjust distributed actor test-cases changed/improved due…
xedin Feb 8, 2024
a2caaa3
[Distributed] Promote SerializationRequirement as a primary associate…
xedin Feb 9, 2024
e85bd1f
[Sema] Distributed: Adjust distributed thunk synthesis to use witness…
xedin Feb 9, 2024
91b6bda
[IRGen] Distributed: Don't attempt to inject protocols without witnes…
xedin Feb 10, 2024
4d4c80b
[IRGen] Distributed: Always invoke `decodeNextArgument` through witne…
xedin Feb 12, 2024
1909b12
[SIL] Distributed: Remove logic related to ad-hoc requirements from S…
xedin Feb 12, 2024
0cc26cf
[AST] Distributed: Make sure that prospective `remoteCall` declaratio…
xedin Feb 12, 2024
aac4e85
[AST] WitnessMatching: Bring back original check for missing Sendable…
xedin Feb 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7498,6 +7498,17 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
/// 'DistributedTargetInvocationResultHandler' protocol.
bool isDistributedTargetInvocationResultHandlerOnReturn() const;

/// Determines whether this declaration is a witness to a
/// protocol requirement with ad-hoc `SerializationRequirement`
/// conformance.
bool isDistributedWitnessWithAdHocSerializationRequirement() const {
return isDistributedActorSystemRemoteCall(/*isVoidResult=*/false) ||
isDistributedTargetInvocationEncoderRecordArgument() ||
isDistributedTargetInvocationEncoderRecordReturnType() ||
isDistributedTargetInvocationDecoderDecodeNextArgument() ||
isDistributedTargetInvocationResultHandlerOnReturn();
}

/// For a method of a class, checks whether it will require a new entry in the
/// vtable.
bool needsNewVTableEntry() const;
Expand Down
32 changes: 0 additions & 32 deletions include/swift/SIL/SILFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,17 +264,6 @@ class SILFunction
/// @_dynamicReplacement(for:) function.
SILFunction *ReplacedFunction = nullptr;

/// This SILFunction REFerences an ad-hoc protocol requirement witness in
/// order to keep it alive, such that it main be obtained in IRGen. Without
/// this explicit reference, the witness would seem not-used, and not be
/// accessible for IRGen.
///
/// Specifically, one such case is the DistributedTargetInvocationDecoder's
/// 'decodeNextArgument' which must be retained, as it is only used from IRGen
/// and such, appears as-if unused in SIL and would get optimized away.
// TODO: Consider making this a general "references adhoc functions" and make it an array?
SILFunction *RefAdHocRequirementFunction = nullptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome to see this gone


Identifier ObjCReplacementFor;

/// The head of a single-linked list of currently alive BasicBlockBitfield.
Expand Down Expand Up @@ -596,27 +585,6 @@ class SILFunction
ReplacedFunction = nullptr;
}

SILFunction *getReferencedAdHocRequirementWitnessFunction() const {
return RefAdHocRequirementFunction;
}
// Marks that this `SILFunction` uses the passed in ad-hoc protocol
// requirement witness `f` and therefore must retain it explicitly,
// otherwise we might not be able to get a reference to it.
void setReferencedAdHocRequirementWitnessFunction(SILFunction *f) {
assert(RefAdHocRequirementFunction == nullptr && "already set");

if (f == nullptr)
return;
RefAdHocRequirementFunction = f;
RefAdHocRequirementFunction->incrementRefCount();
}
void dropReferencedAdHocRequirementWitnessFunction() {
if (!RefAdHocRequirementFunction)
return;
RefAdHocRequirementFunction->decrementRefCount();
RefAdHocRequirementFunction = nullptr;
}

bool hasObjCReplacement() const {
return !ObjCReplacementFor.empty();
}
Expand Down
31 changes: 31 additions & 0 deletions include/swift/Sema/ConstraintLocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -1293,6 +1293,37 @@ class ConstraintLocatorBuilder {
return false;
}

std::optional<std::pair</*witness=*/ValueDecl *, GenericTypeParamType *>>
isForWitnessGenericParameterRequirement() const {
SmallVector<LocatorPathElt, 2> path;
getLocatorParts(path);

// -> witness -> generic env -> requirement
if (path.size() < 3)
return std::nullopt;

GenericTypeParamType *GP = nullptr;
if (auto reqLoc =
path.back().getAs<LocatorPathElt::TypeParameterRequirement>()) {
path.pop_back();
if (auto openedGeneric =
path.back().getAs<LocatorPathElt::OpenedGeneric>()) {
auto signature = openedGeneric->getSignature();
auto requirement = signature.getRequirements()[reqLoc->getIndex()];
GP = requirement.getFirstType()->getAs<GenericTypeParamType>();
}
}

if (!GP)
return std::nullopt;

auto witness = path.front().getAs<LocatorPathElt::Witness>();
if (!witness)
return std::nullopt;

return std::make_pair(witness->getDecl(), GP);
}

/// Checks whether this locator is describing an argument application for a
/// non-ephemeral parameter.
bool isNonEphemeralParameterApplication() const {
Expand Down
19 changes: 18 additions & 1 deletion include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -1627,6 +1627,11 @@ class Solution {
llvm::DenseMap<ConstraintLocator *, UnresolvedDotExpr *>
ImplicitCallAsFunctionRoots;

/// The set of conformances synthesized during solving (i.e. for
/// ad-hoc distributed `SerializationRequirement` conformances).
llvm::MapVector<ConstraintLocator *, ProtocolConformanceRef>
SynthesizedConformances;

/// Record a new argument matching choice for given locator that maps a
/// single argument to a single parameter.
void recordSingleArgMatchingChoice(ConstraintLocator *locator);
Expand Down Expand Up @@ -1667,11 +1672,15 @@ class Solution {
/// Compute the set of substitutions for a generic signature opened at the
/// given locator.
///
/// \param decl The underlying declaration for which the substitutions are
/// computed.
///
/// \param sig The generic signature.
///
/// \param locator The locator that describes where the substitutions came
/// from.
SubstitutionMap computeSubstitutions(GenericSignature sig,
SubstitutionMap computeSubstitutions(NullablePtr<ValueDecl> decl,
GenericSignature sig,
ConstraintLocator *locator) const;

/// Resolves the contextual substitutions for a reference to a declaration
Expand Down Expand Up @@ -2411,6 +2420,11 @@ class ConstraintSystem {
llvm::SmallMapVector<ConstraintLocator *, UnresolvedDotExpr *, 2>
ImplicitCallAsFunctionRoots;

/// The set of conformances synthesized during solving (i.e. for
/// ad-hoc distributed `SerializationRequirement` conformances).
llvm::MapVector<ConstraintLocator *, ProtocolConformanceRef>
SynthesizedConformances;

private:
/// Describe the candidate expression for partial solving.
/// This class used by shrink & solve methods which apply
Expand Down Expand Up @@ -2934,6 +2948,9 @@ class ConstraintSystem {
/// The length of \c ImplicitCallAsFunctionRoots.
unsigned numImplicitCallAsFunctionRoots;

/// The length of \c SynthesizedConformances.
unsigned numSynthesizedConformances;

/// The previous score.
Score PreviousScore;

Expand Down
6 changes: 5 additions & 1 deletion lib/AST/DistributedDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,10 @@ bool swift::checkDistributedSerializationRequirementIsExactlyCodable(
bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn) const {
auto &C = getASTContext();
auto module = getParentModule();
auto *DC = getDeclContext();

if (!DC->isTypeContext() || !isGeneric())
return false;

// === Check the name
auto callId = isVoidReturn ? C.Id_remoteCallVoid : C.Id_remoteCall;
Expand All @@ -398,7 +402,7 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
ProtocolDecl *systemProto =
C.getDistributedActorSystemDecl();

auto systemNominal = getDeclContext()->getSelfNominalTypeDecl();
auto systemNominal = DC->getSelfNominalTypeDecl();
auto distSystemConformance = module->lookupConformance(
systemNominal->getDeclaredInterfaceType(), systemProto);

Expand Down
135 changes: 16 additions & 119 deletions lib/IRGen/GenDistributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ struct ArgumentDecoderInfo {
/// The type of `decodeNextArgument` method.
CanSILFunctionType MethodType;

/// Protocol requirements associated with the generic
/// parameter `Argument` of this decode method.
GenericSignature::RequiredProtocols ProtocolRequirements;

// Witness metadata for conformance to DistributedTargetInvocationDecoder
// protocol.
WitnessMetadata Witness;
Expand All @@ -94,31 +90,19 @@ struct ArgumentDecoderInfo {
FunctionPointer decodeNextArgumentPtr,
CanSILFunctionType decodeNextArgumentTy)
: Decoder(decoder), MethodPtr(decodeNextArgumentPtr),
MethodType(decodeNextArgumentTy),
ProtocolRequirements(findProtocolRequirements(decodeNextArgumentTy)) {
MethodType(decodeNextArgumentTy) {
Witness.SelfMetadata = decoderType;
Witness.SelfWitnessTable = decoderWitnessTable;
}

CanSILFunctionType getMethodType() const { return MethodType; }

ArrayRef<ProtocolDecl *> getProtocolRequirements() const {
return ProtocolRequirements;
WitnessMetadata *getWitnessMetadata() const {
return const_cast<WitnessMetadata *>(&Witness);
}

/// Form a callee to a decode method - `decodeNextArgument`.
Callee getCallee() const;

private:
static GenericSignature::RequiredProtocols
findProtocolRequirements(CanSILFunctionType decodeMethodTy) {
auto signature = decodeMethodTy->getInvocationGenericSignature();
auto genericParams = signature.getGenericParams();

// func decodeNextArgument<Arg : #SerializationRequirement#>() throws -> Arg
assert(genericParams.size() == 1);
return signature->getRequiredProtocols(genericParams.front());
}
};

class DistributedAccessor {
Expand Down Expand Up @@ -156,10 +140,6 @@ class DistributedAccessor {
llvm::Value *argumentType, const SILParameterInfo &param,
Explosion &arguments);

void lookupWitnessTables(llvm::Value *value,
ArrayRef<ProtocolDecl *> protocols,
Explosion &witnessTables);

/// Load witness table addresses (if any) from the given buffer
/// into the given argument explosion.
///
Expand Down Expand Up @@ -417,17 +397,13 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
// substitution Argument -> <argument metadata>
decodeArgs.add(argumentType);

// Lookup witness tables for the requirement on the argument type.
lookupWitnessTables(argumentType, decoder.getProtocolRequirements(),
decodeArgs);

Address calleeErrorSlot;
llvm::Value *decodeError = nullptr;

emission->begin();
{
emission->setArgs(decodeArgs, /*isOutlined=*/false,
/*witnessMetadata=*/nullptr);
/*witnessMetadata=*/decoder.getWitnessMetadata());

Explosion result;
emission->emitToExplosion(result, /*isOutlined=*/false);
Expand Down Expand Up @@ -528,37 +504,6 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
}
}

void DistributedAccessor::lookupWitnessTables(
llvm::Value *value, ArrayRef<ProtocolDecl *> protocols,
Explosion &witnessTables) {
auto conformsToProtocol = IGM.getConformsToProtocolFunctionPointer();

for (auto *protocol : protocols) {
auto *protocolDescriptor = IGM.getAddrOfProtocolDescriptor(protocol);
auto *witnessTable =
IGF.Builder.CreateCall(conformsToProtocol, {value, protocolDescriptor});

auto failBB = IGF.createBasicBlock("missing-witness");
auto contBB = IGF.createBasicBlock("");

auto isNull = IGF.Builder.CreateICmpEQ(
witnessTable, llvm::ConstantPointerNull::get(IGM.WitnessTablePtrTy));
IGF.Builder.CreateCondBr(isNull, failBB, contBB);

// This operation shouldn't fail because runtime should have checked that
// a particular argument type conforms to `SerializationRequirement`
// of the distributed actor the decoder is used for. If it does fail
// then accessor should trap.
{
IGF.Builder.emitBlock(failBB);
IGF.emitTrap("missing witness table", /*EmitUnreachable=*/true);
}

IGF.Builder.emitBlock(contBB);
witnessTables.add(witnessTable);
}
}

void DistributedAccessor::emitLoadOfWitnessTables(llvm::Value *witnessTables,
llvm::Value *numTables,
unsigned expectedWitnessTables,
Expand Down Expand Up @@ -803,70 +748,22 @@ DistributedAccessor::getCalleeForDistributedTarget(llvm::Value *self) const {

ArgumentDecoderInfo DistributedAccessor::findArgumentDecoder(
llvm::Value *decoder, llvm::Value *decoderTy, llvm::Value *witnessTable) {
auto *actor = getDistributedActorOf(Target);
auto expansionContext = IGM.getMaximalTypeExpansionContext();

auto *decodeFn = IGM.Context.getDistributedActorArgumentDecodingMethod(actor);
assert(decodeFn && "no suitable decoder?");

auto methodTy = IGM.getSILTypes().getConstantFunctionType(
expansionContext, SILDeclRef(decodeFn));

auto fpKind = FunctionPointerKind::defaultAsync();
auto signature = IGM.getSignature(methodTy, fpKind);

// If the decoder class is `final`, let's emit a direct reference.
auto *decoderDecl = decodeFn->getDeclContext()->getSelfNominalTypeDecl();

// If decoder is a class, need to load it first because generic parameter
// is passed indirectly. This is good for structs and enums because
// `decodeNextArgument` is a mutating method, but not for classes because
// in that case heap object is mutated directly.
bool usesDispatchThunk = false;
auto &C = IGM.Context;

if (auto classDecl = dyn_cast<ClassDecl>(decoderDecl)) {
auto selfTy = methodTy->getSelfParameter().getSILStorageType(
IGM.getSILModule(), methodTy, expansionContext);
auto decoderProtocol = C.getDistributedTargetInvocationDecoderDecl();
SILDeclRef decodeNextArgumentRef(
decoderProtocol->getSingleRequirement(C.Id_decodeNextArgument));

auto &classTI = IGM.getTypeInfo(selfTy).as<ClassTypeInfo>();
auto &classLayout = classTI.getClassLayout(IGM, selfTy,
/*forBackwardDeployment=*/false);
llvm::Constant *fnPtr =
IGM.getAddrOfDispatchThunk(decodeNextArgumentRef, NotForDefinition);

llvm::Value *typedDecoderPtr = IGF.Builder.CreateBitCast(
decoder, classLayout.getType()->getPointerTo()->getPointerTo());

Explosion instance;

classTI.loadAsTake(IGF,
{typedDecoderPtr, classTI.getStorageType(),
classTI.getBestKnownAlignment()},
instance);

decoder = instance.claimNext();

/// When using library evolution functions have another "dispatch thunk"
/// so we must use this instead of the decodeFn directly.
usesDispatchThunk =
getMethodDispatch(decodeFn) == swift::MethodDispatch::Class &&
classDecl->hasResilientMetadata();
}

FunctionPointer methodPtr;

if (usesDispatchThunk) {
auto fnPtr = IGM.getAddrOfDispatchThunk(SILDeclRef(decodeFn), NotForDefinition);
methodPtr = FunctionPointer::createUnsigned(
methodTy, fnPtr, signature, /*useSignature=*/true);
} else {
SILFunction *decodeSILFn = IGM.getSILModule().lookUpFunction(SILDeclRef(decodeFn));
auto fnPtr = IGM.getAddrOfSILFunction(decodeSILFn, NotForDefinition,
/*isDynamicallyReplaceable=*/false);
methodPtr = FunctionPointer::forDirect(
classifyFunctionPointerKind(decodeSILFn), fnPtr,
/*secondaryValue=*/nullptr, signature);
}
auto fnType = IGM.getSILTypes().getConstantFunctionType(
IGM.getMaximalTypeExpansionContext(), decodeNextArgumentRef);

return {decoder, decoderTy, witnessTable, methodPtr, methodTy};
auto sig = IGM.getSignature(fnType);
auto fn = FunctionPointer::forDirect(fnType, fnPtr,
/*secondaryValue=*/nullptr, sig, true);
return {decoder, decoderTy, witnessTable, fn, fnType};
}

SILType DistributedAccessor::getResultType() const {
Expand Down
Loading