Skip to content

[Distributed] Augment distributed accessor to lookup witness tables for decode call #41101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/AST/KnownIdentifiers.def
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ IDENTIFIER(using)
IDENTIFIER(InvocationDecoder)
IDENTIFIER(whenLocal)
IDENTIFIER(decodeNextArgument)
IDENTIFIER(SerializationRequirement)

#undef IDENTIFIER
#undef IDENTIFIER_
Expand Down
67 changes: 66 additions & 1 deletion lib/IRGen/GenDistributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "swift/ABI/MetadataValues.h"
#include "swift/AST/ExtInfo.h"
#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/GenericSignature.h"
#include "swift/AST/ProtocolConformanceRef.h"
#include "swift/IRGen/Linking.h"
#include "swift/SIL/SILFunction.h"
Expand Down Expand Up @@ -68,19 +69,44 @@ llvm::Value *irgen::emitDistributedActorInitializeRemote(
namespace {

struct ArgumentDecoderInfo {
/// The instance of the decoder this information belongs to.
llvm::Value *Decoder;

/// The type of `decodeNextArgument` method.
CanSILFunctionType MethodType;

/// The pointer to `decodeNextArgument` method which
/// could be used to form a call to it.
FunctionPointer MethodPtr;

/// Protocol requirements associated with the generic
/// parameter `Argument` of this decode method.
GenericSignature::RequiredProtocols ProtocolRequirements;

ArgumentDecoderInfo(llvm::Value *decoder, CanSILFunctionType decodeMethodTy,
FunctionPointer decodePtr)
: Decoder(decoder), MethodType(decodeMethodTy), MethodPtr(decodePtr) {}
: Decoder(decoder), MethodType(decodeMethodTy), MethodPtr(decodePtr),
ProtocolRequirements(findProtocolRequirements(decodeMethodTy)) {}

CanSILFunctionType getMethodType() const { return MethodType; }

ArrayRef<ProtocolDecl *> getProtocolRequirements() const {
return ProtocolRequirements;
}

/// Form a callee to a decode method - `decodeNextArgument`.
Callee getCallee() const;

private:
static GenericSignature::RequiredProtocols
findProtocolRequirements(CanSILFunctionType decodeMethodTy) {
auto signature = decodeMethodTy->getInvocationGenericSignature();
auto genericParams = signature.getGenericParams();

// func decodeNextArgument<Arg : #SerializationRequirement#>() throws -> Arg
assert(genericParams.size() == 1);
return signature->getRequiredProtocols(genericParams.front());
}
};

class DistributedAccessor {
Expand Down Expand Up @@ -115,6 +141,10 @@ class DistributedAccessor {
llvm::Value *argumentType, const SILParameterInfo &param,
Explosion &arguments);

void lookupWitnessTables(llvm::Value *value,
ArrayRef<ProtocolDecl *> protocols,
Explosion &witnessTables);

/// Load witness table addresses (if any) from the given buffer
/// into the given argument explosion.
///
Expand Down Expand Up @@ -329,6 +359,10 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
// substitution Argument -> <argument metadata>
decodeArgs.add(argumentType);

// Lookup witness tables for the requirement on the argument type.
lookupWitnessTables(argumentType, decoder.getProtocolRequirements(),
decodeArgs);

Address calleeErrorSlot;
llvm::Value *decodeError = nullptr;

Expand Down Expand Up @@ -426,6 +460,37 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
}
}

void DistributedAccessor::lookupWitnessTables(
llvm::Value *value, ArrayRef<ProtocolDecl *> protocols,
Explosion &witnessTables) {
auto conformsToProtocol = IGM.getConformsToProtocolFn();

for (auto *protocol : protocols) {
auto *protocolDescriptor = IGM.getAddrOfProtocolDescriptor(protocol);
auto *witnessTable =
IGF.Builder.CreateCall(conformsToProtocol, {value, protocolDescriptor});

auto failBB = IGF.createBasicBlock("missing-witness");
auto contBB = IGF.createBasicBlock("");

auto isNull = IGF.Builder.CreateICmpEQ(
witnessTable, llvm::ConstantPointerNull::get(IGM.WitnessTablePtrTy));
IGF.Builder.CreateCondBr(isNull, failBB, contBB);

// This operation shouldn't fail because runtime should have checked that
// a particular argument type conforms to `SerializationRequirement`
// of the distributed actor the decoder is used for. If it does fail
// then accessor should trap.
{
IGF.Builder.emitBlock(failBB);
IGF.emitTrap("missing witness table", /*EmitUnreachable=*/true);
}

IGF.Builder.emitBlock(contBB);
witnessTables.add(witnessTable);
}
}

void DistributedAccessor::emitLoadOfWitnessTables(llvm::Value *witnessTables,
llvm::Value *numTables,
unsigned expectedWitnessTables,
Expand Down
56 changes: 48 additions & 8 deletions lib/Sema/TypeCheckDistributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,26 +472,66 @@ GetDistributedActorArgumentDecodingMethodRequest::evaluate(Evaluator &evaluator,
auto members = TypeChecker::lookupMember(actor->getDeclContext(), decoderTy,
DeclNameRef(ctx.Id_decodeNextArgument));

// typealias SerializationRequirement = any ...
auto serializerType = getAssociatedTypeOfDistributedSystem(
actor, ctx.Id_SerializationRequirement)
->castTo<ExistentialType>()
->getConstraintType()
->getDesugaredType();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see, totally would have missed that last two calls


llvm::SmallPtrSet<ProtocolDecl *, 2> serializationReqs;
if (auto composition = serializerType->getAs<ProtocolCompositionType>()) {
for (auto member : composition->getMembers()) {
if (auto *protocol = member->getAs<ProtocolType>())
serializationReqs.insert(protocol->getDecl());
}
} else {
auto protocol = serializerType->castTo<ProtocolType>()->getDecl();
serializationReqs.insert(protocol);
}

SmallVector<FuncDecl *, 2> candidates;
// Looking for `decodeNextArgument<Arg>() throws -> Arg`
// Looking for `decodeNextArgument<Arg: <SerializationReq>>() throws -> Arg`
for (auto &member : members) {
auto *FD = dyn_cast<FuncDecl>(member.getValueDecl());
if (!FD || FD->hasAsync() || !FD->hasThrows())
continue;

auto *params = FD->getParameters();
// No arguemnts.
if (params->size() != 0)
continue;

auto genericParamList = FD->getGenericParams();
if (genericParamList->size() == 1) {
auto paramTy = genericParamList->getParams()[0]
->getInterfaceType()
->getMetatypeInstanceType();
// A single generic parameter.
if (genericParamList->size() != 1)
continue;

if (FD->getResultInterfaceType()->isEqual(paramTy))
candidates.push_back(FD);
}
auto paramTy = genericParamList->getParams()[0]
->getInterfaceType()
->getMetatypeInstanceType();

// `decodeNextArgument` should return its generic parameter value
if (!FD->getResultInterfaceType()->isEqual(paramTy))
continue;

// Let's find out how many serialization requirements does this method cover
// e.g. `Codable` is two requirements - `Encodable` and `Decodable`.
unsigned numSerializationReqsCovered = llvm::count_if(
FD->getGenericRequirements(), [&](const Requirement &requirement) {
if (!(requirement.getFirstType()->isEqual(paramTy) &&
requirement.getKind() == RequirementKind::Conformance))
return 0;

return serializationReqs.count(requirement.getProtocolDecl()) ? 1 : 0;
});

// If the current method covers all of the serialization requirements,
// it's a match. Note that it might also have other requirements, but
// we let that go as long as there are no two candidates that differ
// only in generic requirements.
if (numSerializationReqsCovered == serializationReqs.size())
candidates.push_back(FD);
}

// Type-checker should reject any definition of invocation decoder
Expand Down
8 changes: 1 addition & 7 deletions stdlib/public/Distributed/DistributedActorSystem.swift
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ extension DistributedActorSystem {
let returnType = try invocationDecoder.decodeReturnType() ?? returnTypeFromTypeInfo
// let errorType = try invocation.decodeErrorType() // TODO: decide how to use?

var decoderAny = invocationDecoder as Any
// Execute the target!
try await _executeDistributedTarget(
on: actor,
Expand Down Expand Up @@ -419,12 +418,7 @@ public protocol DistributedTargetInvocationDecoder : AnyObject {
// /// buffer for all the arguments and their expected types. The 'pointer' passed here is a pointer
// /// to a "slot" in that pre-allocated buffer. That buffer will then be passed to a thunk that
// /// performs the actual distributed (local) instance method invocation.
// mutating func decodeNextArgument<Argument: SerializationRequirement>(
// into pointer: UnsafeMutablePointer<Argument> // pointer to our hbuffer
// ) throws

// FIXME(distributed): remove this since it must have the ': SerializationRequirement'
func decodeNextArgument<Argument>() throws -> Argument
// mutating func decodeNextArgument<Argument: SerializationRequirement>() throws -> Argument

func decodeErrorType() throws -> Any.Type?

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ public class FakeInvocationDecoder : DistributedTargetInvocationDecoder {
return genericSubs
}

public func decodeNextArgument<Argument>() throws -> Argument {
public func decodeNextArgument<Argument: SerializationRequirement>() throws -> Argument {
guard argumentIndex < arguments.count else {
fatalError("Attempted to decode more arguments than stored! Index: \(argumentIndex), args: \(arguments)")
}
Expand Down
24 changes: 14 additions & 10 deletions test/Distributed/Inputs/dynamic_replacement_da_decl.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ struct ActorAddress: Hashable, Sendable, Codable {

final class FakeActorSystem: DistributedActorSystem {
typealias ActorID = ActorAddress
typealias InvocationDecoder = FakeInvocation
typealias InvocationEncoder = FakeInvocation
typealias InvocationDecoder = FakeInvocationDecoder
typealias InvocationEncoder = FakeInvocationEncoder
typealias SerializationRequirement = Codable

func resolve<Act>(id: ActorID, as actorType: Act.Type) throws -> Act?
Expand Down Expand Up @@ -82,19 +82,23 @@ struct FakeDistributedSystemError: DistributedActorSystemError {
let message: String
}

class FakeInvocation: DistributedTargetInvocationEncoder, DistributedTargetInvocationDecoder {
// === Sending / encoding -------------------------------------------------
struct FakeInvocationEncoder: DistributedTargetInvocationEncoder {
typealias SerializationRequirement = Codable

func recordGenericSubstitution<T>(_ type: T.Type) throws {}
func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws {}
func recordReturnType<R: SerializationRequirement>(_ type: R.Type) throws {}
func recordErrorType<E: Error>(_ type: E.Type) throws {}
func doneRecording() throws {}
mutating func recordGenericSubstitution<T>(_ type: T.Type) throws {}
mutating func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws {}
mutating func recordReturnType<R: SerializationRequirement>(_ type: R.Type) throws {}
mutating func recordErrorType<E: Error>(_ type: E.Type) throws {}
mutating func doneRecording() throws {}
}

// === Receiving / decoding -------------------------------------------------
// === Receiving / decoding -------------------------------------------------
class FakeInvocationDecoder : DistributedTargetInvocationDecoder {
typealias SerializationRequirement = Codable

func decodeGenericSubstitutions() throws -> [Any.Type] { [] }
func decodeNextArgument<Argument>() throws -> Argument { fatalError() }
func decodeNextArgument<Argument: SerializationRequirement>() throws -> Argument { fatalError() }
func decodeReturnType() throws -> Any.Type? { nil }
func decodeErrorType() throws -> Any.Type? { nil }
}
Expand Down
2 changes: 1 addition & 1 deletion test/Distributed/Runtime/distributed_actor_decode.swift
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class FakeInvocation: DistributedTargetInvocationEncoder, DistributedTargetInvoc
// === Receiving / decoding -------------------------------------------------

func decodeGenericSubstitutions() throws -> [Any.Type] { [] }
func decodeNextArgument<Argument>() throws -> Argument { fatalError() }
func decodeNextArgument<Argument: SerializationRequirement>() throws -> Argument { fatalError() }
func decodeReturnType() throws -> Any.Type? { nil }
func decodeErrorType() throws -> Any.Type? { nil }
}
Expand Down
2 changes: 1 addition & 1 deletion test/Distributed/Runtime/distributed_actor_deinit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class FakeDistributedInvocation: DistributedTargetInvocationEncoder, Distributed
func decodeGenericSubstitutions() throws -> [Any.Type] {
[]
}
func decodeNextArgument<Argument>() throws -> Argument {
func decodeNextArgument<Argument: SerializationRequirement>() throws -> Argument {
fatalError()
}
func decodeReturnType() throws -> Any.Type? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ struct ActorAddress: Sendable, Hashable, Codable {
//final class FakeActorSystem: DistributedActorSystem {
struct FakeActorSystem: DistributedActorSystem {
typealias ActorID = ActorAddress
typealias InvocationDecoder = FakeInvocation
typealias InvocationEncoder = FakeInvocation
typealias InvocationDecoder = FakeInvocationDecoder
typealias InvocationEncoder = FakeInvocationEncoder
typealias SerializationRequirement = Codable

init() {}
Expand Down Expand Up @@ -105,16 +105,17 @@ struct FakeActorSystem: DistributedActorSystem {

}

struct FakeInvocation: DistributedTargetInvocationEncoder, DistributedTargetInvocationDecoder {

struct FakeInvocationEncoder: DistributedTargetInvocationEncoder {
typealias SerializationRequirement = Codable

var types: [Any.Type] = []
var substitutions: [Any.Type] = []
var arguments: [Any] = []
var returnType: Any.Type? = nil
var errorType: Any.Type? = nil

mutating func recordGenericSubstitution<T>(_ type: T.Type) throws {
types.append(type)
substitutions.append(type)
}
mutating func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws {
arguments.append(argument)
Expand All @@ -127,17 +128,45 @@ struct FakeInvocation: DistributedTargetInvocationEncoder, DistributedTargetInvo
}
mutating func doneRecording() throws {}

// === Receiving / decoding -------------------------------------------------
// For testing only
func makeDecoder() -> FakeInvocationDecoder {
return .init(
args: arguments,
substitutions: substitutions,
returnType: returnType,
errorType: errorType
)
}
}


class FakeInvocationDecoder : DistributedTargetInvocationDecoder {
typealias SerializationRequirement = Codable

var arguments: [Any] = []
var substitutions: [Any.Type] = []
var returnType: Any.Type? = nil
var errorType: Any.Type? = nil

init(
args: [Any],
substitutions: [Any.Type] = [],
returnType: Any.Type? = nil,
errorType: Any.Type? = nil
) {
self.arguments = args
self.substitutions = substitutions
self.returnType = returnType
self.errorType = errorType
}

// === Receiving / decoding -------------------------------------------------
func decodeGenericSubstitutions() throws -> [Any.Type] {
[]
return substitutions
}

var argumentIndex: Int = 0
mutating func decodeNextArgument<Argument>(
_ argumentType: Argument.Type,
into pointer: UnsafeMutablePointer<Argument>
) throws {
func decodeNextArgument<Argument: SerializationRequirement>() throws -> Argument {
guard argumentIndex < arguments.count else {
fatalError("Attempted to decode more arguments than stored! Index: \(argumentIndex), args: \(arguments)")
}
Expand All @@ -148,16 +177,18 @@ struct FakeInvocation: DistributedTargetInvocationEncoder, DistributedTargetInvo
}

print(" > decode argument: \(argument)")
pointer.pointee = argument
argumentIndex += 1
return argument
}

func decodeErrorType() throws -> Any.Type? {
self.errorType
public func decodeErrorType() throws -> Any.Type? {
print(" > decode return type: \(errorType.map { String(describing: $0) } ?? "nil")")
return self.errorType
}

func decodeReturnType() throws -> Any.Type? {
self.returnType
public func decodeReturnType() throws -> Any.Type? {
print(" > decode return type: \(returnType.map { String(describing: $0) } ?? "nil")")
return self.returnType
}
}

Expand Down
Loading