Skip to content

Commit 9ea2f96

Browse files
authored
Merge pull request #41101 from xedin/load-witness-tables-for-decodeNextArgument
[Distributed] Augment distributed accessor to lookup witness tables for decode call
2 parents ea997a0 + 9694597 commit 9ea2f96

27 files changed

+248
-88
lines changed

include/swift/AST/KnownIdentifiers.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ IDENTIFIER(using)
292292
IDENTIFIER(InvocationDecoder)
293293
IDENTIFIER(whenLocal)
294294
IDENTIFIER(decodeNextArgument)
295+
IDENTIFIER(SerializationRequirement)
295296

296297
#undef IDENTIFIER
297298
#undef IDENTIFIER_

lib/IRGen/GenDistributed.cpp

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "swift/ABI/MetadataValues.h"
3737
#include "swift/AST/ExtInfo.h"
3838
#include "swift/AST/GenericEnvironment.h"
39+
#include "swift/AST/GenericSignature.h"
3940
#include "swift/AST/ProtocolConformanceRef.h"
4041
#include "swift/IRGen/Linking.h"
4142
#include "swift/SIL/SILFunction.h"
@@ -68,19 +69,44 @@ llvm::Value *irgen::emitDistributedActorInitializeRemote(
6869
namespace {
6970

7071
struct ArgumentDecoderInfo {
72+
/// The instance of the decoder this information belongs to.
7173
llvm::Value *Decoder;
7274

75+
/// The type of `decodeNextArgument` method.
7376
CanSILFunctionType MethodType;
77+
78+
/// The pointer to `decodeNextArgument` method which
79+
/// could be used to form a call to it.
7480
FunctionPointer MethodPtr;
7581

82+
/// Protocol requirements associated with the generic
83+
/// parameter `Argument` of this decode method.
84+
GenericSignature::RequiredProtocols ProtocolRequirements;
85+
7686
ArgumentDecoderInfo(llvm::Value *decoder, CanSILFunctionType decodeMethodTy,
7787
FunctionPointer decodePtr)
78-
: Decoder(decoder), MethodType(decodeMethodTy), MethodPtr(decodePtr) {}
88+
: Decoder(decoder), MethodType(decodeMethodTy), MethodPtr(decodePtr),
89+
ProtocolRequirements(findProtocolRequirements(decodeMethodTy)) {}
7990

8091
CanSILFunctionType getMethodType() const { return MethodType; }
8192

93+
ArrayRef<ProtocolDecl *> getProtocolRequirements() const {
94+
return ProtocolRequirements;
95+
}
96+
8297
/// Form a callee to a decode method - `decodeNextArgument`.
8398
Callee getCallee() const;
99+
100+
private:
101+
static GenericSignature::RequiredProtocols
102+
findProtocolRequirements(CanSILFunctionType decodeMethodTy) {
103+
auto signature = decodeMethodTy->getInvocationGenericSignature();
104+
auto genericParams = signature.getGenericParams();
105+
106+
// func decodeNextArgument<Arg : #SerializationRequirement#>() throws -> Arg
107+
assert(genericParams.size() == 1);
108+
return signature->getRequiredProtocols(genericParams.front());
109+
}
84110
};
85111

86112
class DistributedAccessor {
@@ -115,6 +141,10 @@ class DistributedAccessor {
115141
llvm::Value *argumentType, const SILParameterInfo &param,
116142
Explosion &arguments);
117143

144+
void lookupWitnessTables(llvm::Value *value,
145+
ArrayRef<ProtocolDecl *> protocols,
146+
Explosion &witnessTables);
147+
118148
/// Load witness table addresses (if any) from the given buffer
119149
/// into the given argument explosion.
120150
///
@@ -329,6 +359,10 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
329359
// substitution Argument -> <argument metadata>
330360
decodeArgs.add(argumentType);
331361

362+
// Lookup witness tables for the requirement on the argument type.
363+
lookupWitnessTables(argumentType, decoder.getProtocolRequirements(),
364+
decodeArgs);
365+
332366
Address calleeErrorSlot;
333367
llvm::Value *decodeError = nullptr;
334368

@@ -426,6 +460,37 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
426460
}
427461
}
428462

463+
void DistributedAccessor::lookupWitnessTables(
464+
llvm::Value *value, ArrayRef<ProtocolDecl *> protocols,
465+
Explosion &witnessTables) {
466+
auto conformsToProtocol = IGM.getConformsToProtocolFn();
467+
468+
for (auto *protocol : protocols) {
469+
auto *protocolDescriptor = IGM.getAddrOfProtocolDescriptor(protocol);
470+
auto *witnessTable =
471+
IGF.Builder.CreateCall(conformsToProtocol, {value, protocolDescriptor});
472+
473+
auto failBB = IGF.createBasicBlock("missing-witness");
474+
auto contBB = IGF.createBasicBlock("");
475+
476+
auto isNull = IGF.Builder.CreateICmpEQ(
477+
witnessTable, llvm::ConstantPointerNull::get(IGM.WitnessTablePtrTy));
478+
IGF.Builder.CreateCondBr(isNull, failBB, contBB);
479+
480+
// This operation shouldn't fail because runtime should have checked that
481+
// a particular argument type conforms to `SerializationRequirement`
482+
// of the distributed actor the decoder is used for. If it does fail
483+
// then accessor should trap.
484+
{
485+
IGF.Builder.emitBlock(failBB);
486+
IGF.emitTrap("missing witness table", /*EmitUnreachable=*/true);
487+
}
488+
489+
IGF.Builder.emitBlock(contBB);
490+
witnessTables.add(witnessTable);
491+
}
492+
}
493+
429494
void DistributedAccessor::emitLoadOfWitnessTables(llvm::Value *witnessTables,
430495
llvm::Value *numTables,
431496
unsigned expectedWitnessTables,

lib/Sema/TypeCheckDistributed.cpp

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -472,26 +472,66 @@ GetDistributedActorArgumentDecodingMethodRequest::evaluate(Evaluator &evaluator,
472472
auto members = TypeChecker::lookupMember(actor->getDeclContext(), decoderTy,
473473
DeclNameRef(ctx.Id_decodeNextArgument));
474474

475+
// typealias SerializationRequirement = any ...
476+
auto serializerType = getAssociatedTypeOfDistributedSystem(
477+
actor, ctx.Id_SerializationRequirement)
478+
->castTo<ExistentialType>()
479+
->getConstraintType()
480+
->getDesugaredType();
481+
482+
llvm::SmallPtrSet<ProtocolDecl *, 2> serializationReqs;
483+
if (auto composition = serializerType->getAs<ProtocolCompositionType>()) {
484+
for (auto member : composition->getMembers()) {
485+
if (auto *protocol = member->getAs<ProtocolType>())
486+
serializationReqs.insert(protocol->getDecl());
487+
}
488+
} else {
489+
auto protocol = serializerType->castTo<ProtocolType>()->getDecl();
490+
serializationReqs.insert(protocol);
491+
}
492+
475493
SmallVector<FuncDecl *, 2> candidates;
476-
// Looking for `decodeNextArgument<Arg>() throws -> Arg`
494+
// Looking for `decodeNextArgument<Arg: <SerializationReq>>() throws -> Arg`
477495
for (auto &member : members) {
478496
auto *FD = dyn_cast<FuncDecl>(member.getValueDecl());
479497
if (!FD || FD->hasAsync() || !FD->hasThrows())
480498
continue;
481499

482500
auto *params = FD->getParameters();
501+
// No arguemnts.
483502
if (params->size() != 0)
484503
continue;
485504

486505
auto genericParamList = FD->getGenericParams();
487-
if (genericParamList->size() == 1) {
488-
auto paramTy = genericParamList->getParams()[0]
489-
->getInterfaceType()
490-
->getMetatypeInstanceType();
506+
// A single generic parameter.
507+
if (genericParamList->size() != 1)
508+
continue;
491509

492-
if (FD->getResultInterfaceType()->isEqual(paramTy))
493-
candidates.push_back(FD);
494-
}
510+
auto paramTy = genericParamList->getParams()[0]
511+
->getInterfaceType()
512+
->getMetatypeInstanceType();
513+
514+
// `decodeNextArgument` should return its generic parameter value
515+
if (!FD->getResultInterfaceType()->isEqual(paramTy))
516+
continue;
517+
518+
// Let's find out how many serialization requirements does this method cover
519+
// e.g. `Codable` is two requirements - `Encodable` and `Decodable`.
520+
unsigned numSerializationReqsCovered = llvm::count_if(
521+
FD->getGenericRequirements(), [&](const Requirement &requirement) {
522+
if (!(requirement.getFirstType()->isEqual(paramTy) &&
523+
requirement.getKind() == RequirementKind::Conformance))
524+
return 0;
525+
526+
return serializationReqs.count(requirement.getProtocolDecl()) ? 1 : 0;
527+
});
528+
529+
// If the current method covers all of the serialization requirements,
530+
// it's a match. Note that it might also have other requirements, but
531+
// we let that go as long as there are no two candidates that differ
532+
// only in generic requirements.
533+
if (numSerializationReqsCovered == serializationReqs.size())
534+
candidates.push_back(FD);
495535
}
496536

497537
// Type-checker should reject any definition of invocation decoder

stdlib/public/Distributed/DistributedActorSystem.swift

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,6 @@ extension DistributedActorSystem {
285285
let returnType = try invocationDecoder.decodeReturnType() ?? returnTypeFromTypeInfo
286286
// let errorType = try invocation.decodeErrorType() // TODO: decide how to use?
287287

288-
var decoderAny = invocationDecoder as Any
289288
// Execute the target!
290289
try await _executeDistributedTarget(
291290
on: actor,
@@ -419,12 +418,7 @@ public protocol DistributedTargetInvocationDecoder : AnyObject {
419418
// /// buffer for all the arguments and their expected types. The 'pointer' passed here is a pointer
420419
// /// to a "slot" in that pre-allocated buffer. That buffer will then be passed to a thunk that
421420
// /// performs the actual distributed (local) instance method invocation.
422-
// mutating func decodeNextArgument<Argument: SerializationRequirement>(
423-
// into pointer: UnsafeMutablePointer<Argument> // pointer to our hbuffer
424-
// ) throws
425-
426-
// FIXME(distributed): remove this since it must have the ': SerializationRequirement'
427-
func decodeNextArgument<Argument>() throws -> Argument
421+
// mutating func decodeNextArgument<Argument: SerializationRequirement>() throws -> Argument
428422

429423
func decodeErrorType() throws -> Any.Type?
430424

test/Distributed/Inputs/FakeDistributedActorSystems.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ public class FakeInvocationDecoder : DistributedTargetInvocationDecoder {
311311
return genericSubs
312312
}
313313

314-
public func decodeNextArgument<Argument>() throws -> Argument {
314+
public func decodeNextArgument<Argument: SerializationRequirement>() throws -> Argument {
315315
guard argumentIndex < arguments.count else {
316316
fatalError("Attempted to decode more arguments than stored! Index: \(argumentIndex), args: \(arguments)")
317317
}

test/Distributed/Inputs/dynamic_replacement_da_decl.swift

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ struct ActorAddress: Hashable, Sendable, Codable {
3636

3737
final class FakeActorSystem: DistributedActorSystem {
3838
typealias ActorID = ActorAddress
39-
typealias InvocationDecoder = FakeInvocation
40-
typealias InvocationEncoder = FakeInvocation
39+
typealias InvocationDecoder = FakeInvocationDecoder
40+
typealias InvocationEncoder = FakeInvocationEncoder
4141
typealias SerializationRequirement = Codable
4242

4343
func resolve<Act>(id: ActorID, as actorType: Act.Type) throws -> Act?
@@ -82,19 +82,23 @@ struct FakeDistributedSystemError: DistributedActorSystemError {
8282
let message: String
8383
}
8484

85-
class FakeInvocation: DistributedTargetInvocationEncoder, DistributedTargetInvocationDecoder {
85+
// === Sending / encoding -------------------------------------------------
86+
struct FakeInvocationEncoder: DistributedTargetInvocationEncoder {
8687
typealias SerializationRequirement = Codable
8788

88-
func recordGenericSubstitution<T>(_ type: T.Type) throws {}
89-
func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws {}
90-
func recordReturnType<R: SerializationRequirement>(_ type: R.Type) throws {}
91-
func recordErrorType<E: Error>(_ type: E.Type) throws {}
92-
func doneRecording() throws {}
89+
mutating func recordGenericSubstitution<T>(_ type: T.Type) throws {}
90+
mutating func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws {}
91+
mutating func recordReturnType<R: SerializationRequirement>(_ type: R.Type) throws {}
92+
mutating func recordErrorType<E: Error>(_ type: E.Type) throws {}
93+
mutating func doneRecording() throws {}
94+
}
9395

94-
// === Receiving / decoding -------------------------------------------------
96+
// === Receiving / decoding -------------------------------------------------
97+
class FakeInvocationDecoder : DistributedTargetInvocationDecoder {
98+
typealias SerializationRequirement = Codable
9599

96100
func decodeGenericSubstitutions() throws -> [Any.Type] { [] }
97-
func decodeNextArgument<Argument>() throws -> Argument { fatalError() }
101+
func decodeNextArgument<Argument: SerializationRequirement>() throws -> Argument { fatalError() }
98102
func decodeReturnType() throws -> Any.Type? { nil }
99103
func decodeErrorType() throws -> Any.Type? { nil }
100104
}

test/Distributed/Runtime/distributed_actor_decode.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class FakeInvocation: DistributedTargetInvocationEncoder, DistributedTargetInvoc
113113
// === Receiving / decoding -------------------------------------------------
114114

115115
func decodeGenericSubstitutions() throws -> [Any.Type] { [] }
116-
func decodeNextArgument<Argument>() throws -> Argument { fatalError() }
116+
func decodeNextArgument<Argument: SerializationRequirement>() throws -> Argument { fatalError() }
117117
func decodeReturnType() throws -> Any.Type? { nil }
118118
func decodeErrorType() throws -> Any.Type? { nil }
119119
}

test/Distributed/Runtime/distributed_actor_deinit.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class FakeDistributedInvocation: DistributedTargetInvocationEncoder, Distributed
133133
func decodeGenericSubstitutions() throws -> [Any.Type] {
134134
[]
135135
}
136-
func decodeNextArgument<Argument>() throws -> Argument {
136+
func decodeNextArgument<Argument: SerializationRequirement>() throws -> Argument {
137137
fatalError()
138138
}
139139
func decodeReturnType() throws -> Any.Type? {

test/Distributed/Runtime/distributed_actor_func_calls_remoteCall_generic.swift

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ struct ActorAddress: Sendable, Hashable, Codable {
4545
//final class FakeActorSystem: DistributedActorSystem {
4646
struct FakeActorSystem: DistributedActorSystem {
4747
typealias ActorID = ActorAddress
48-
typealias InvocationDecoder = FakeInvocation
49-
typealias InvocationEncoder = FakeInvocation
48+
typealias InvocationDecoder = FakeInvocationDecoder
49+
typealias InvocationEncoder = FakeInvocationEncoder
5050
typealias SerializationRequirement = Codable
5151

5252
init() {}
@@ -105,16 +105,17 @@ struct FakeActorSystem: DistributedActorSystem {
105105

106106
}
107107

108-
struct FakeInvocation: DistributedTargetInvocationEncoder, DistributedTargetInvocationDecoder {
108+
109+
struct FakeInvocationEncoder: DistributedTargetInvocationEncoder {
109110
typealias SerializationRequirement = Codable
110111

111-
var types: [Any.Type] = []
112+
var substitutions: [Any.Type] = []
112113
var arguments: [Any] = []
113114
var returnType: Any.Type? = nil
114115
var errorType: Any.Type? = nil
115116

116117
mutating func recordGenericSubstitution<T>(_ type: T.Type) throws {
117-
types.append(type)
118+
substitutions.append(type)
118119
}
119120
mutating func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws {
120121
arguments.append(argument)
@@ -127,17 +128,45 @@ struct FakeInvocation: DistributedTargetInvocationEncoder, DistributedTargetInvo
127128
}
128129
mutating func doneRecording() throws {}
129130

130-
// === Receiving / decoding -------------------------------------------------
131+
// For testing only
132+
func makeDecoder() -> FakeInvocationDecoder {
133+
return .init(
134+
args: arguments,
135+
substitutions: substitutions,
136+
returnType: returnType,
137+
errorType: errorType
138+
)
139+
}
140+
}
141+
142+
143+
class FakeInvocationDecoder : DistributedTargetInvocationDecoder {
144+
typealias SerializationRequirement = Codable
131145

146+
var arguments: [Any] = []
147+
var substitutions: [Any.Type] = []
148+
var returnType: Any.Type? = nil
149+
var errorType: Any.Type? = nil
150+
151+
init(
152+
args: [Any],
153+
substitutions: [Any.Type] = [],
154+
returnType: Any.Type? = nil,
155+
errorType: Any.Type? = nil
156+
) {
157+
self.arguments = args
158+
self.substitutions = substitutions
159+
self.returnType = returnType
160+
self.errorType = errorType
161+
}
162+
163+
// === Receiving / decoding -------------------------------------------------
132164
func decodeGenericSubstitutions() throws -> [Any.Type] {
133-
[]
165+
return substitutions
134166
}
135167

136168
var argumentIndex: Int = 0
137-
mutating func decodeNextArgument<Argument>(
138-
_ argumentType: Argument.Type,
139-
into pointer: UnsafeMutablePointer<Argument>
140-
) throws {
169+
func decodeNextArgument<Argument: SerializationRequirement>() throws -> Argument {
141170
guard argumentIndex < arguments.count else {
142171
fatalError("Attempted to decode more arguments than stored! Index: \(argumentIndex), args: \(arguments)")
143172
}
@@ -148,16 +177,18 @@ struct FakeInvocation: DistributedTargetInvocationEncoder, DistributedTargetInvo
148177
}
149178

150179
print(" > decode argument: \(argument)")
151-
pointer.pointee = argument
152180
argumentIndex += 1
181+
return argument
153182
}
154183

155-
func decodeErrorType() throws -> Any.Type? {
156-
self.errorType
184+
public func decodeErrorType() throws -> Any.Type? {
185+
print(" > decode return type: \(errorType.map { String(describing: $0) } ?? "nil")")
186+
return self.errorType
157187
}
158188

159-
func decodeReturnType() throws -> Any.Type? {
160-
self.returnType
189+
public func decodeReturnType() throws -> Any.Type? {
190+
print(" > decode return type: \(returnType.map { String(describing: $0) } ?? "nil")")
191+
return self.returnType
161192
}
162193
}
163194

0 commit comments

Comments
 (0)