Skip to content

Commit ef14ff1

Browse files
committed
[Distributed] Refactor Invocation to Decoder/Encoder
1 parent 930480b commit ef14ff1

25 files changed

+447
-345
lines changed

include/swift/AST/KnownIdentifiers.def

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,11 +267,12 @@ IDENTIFIER(assignID)
267267
IDENTIFIER(resignID)
268268
IDENTIFIER(resolve)
269269
IDENTIFIER(remoteCall)
270-
IDENTIFIER(makeInvocation)
270+
IDENTIFIER(makeInvocationEncoder)
271271
IDENTIFIER(system)
272272
IDENTIFIER(ID)
273273
IDENTIFIER(id)
274274
IDENTIFIER(Invocation)
275+
IDENTIFIER(invocationDecoder)
275276
IDENTIFIER(_distributedActorRemoteInitialize)
276277
IDENTIFIER(_distributedActorDestroy)
277278
IDENTIFIER(__isRemoteActor)

include/swift/AST/KnownProtocols.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ PROTOCOL(Differentiable)
9999
PROTOCOL(DistributedActor)
100100
PROTOCOL(ActorIdentity)
101101
PROTOCOL(DistributedActorSystem)
102+
PROTOCOL(DistributedTargetInvocationEncoder)
103+
PROTOCOL(DistributedTargetInvocationDecoder)
102104

103105
PROTOCOL(AsyncSequence)
104106
PROTOCOL(AsyncIteratorProtocol)

lib/IRGen/GenMeta.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5340,8 +5340,10 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
53405340
case KnownProtocolKind::Differentiable:
53415341
case KnownProtocolKind::FloatingPoint:
53425342
case KnownProtocolKind::Actor:
5343-
case KnownProtocolKind::DistributedActorSystem:
53445343
case KnownProtocolKind::DistributedActor:
5344+
case KnownProtocolKind::DistributedActorSystem:
5345+
case KnownProtocolKind::DistributedTargetInvocationEncoder:
5346+
case KnownProtocolKind::DistributedTargetInvocationDecoder:
53455347
case KnownProtocolKind::ActorIdentity:
53465348
case KnownProtocolKind::SerialExecutor:
53475349
case KnownProtocolKind::Sendable:

stdlib/public/Distributed/DistributedActorSystem.swift

Lines changed: 61 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,14 @@ import _Concurrency
1616
public protocol DistributedActorSystem: Sendable {
1717
/// The identity used by actors that communicate via this transport
1818
associatedtype ActorID: Sendable & Hashable & Codable // TODO: make Codable conditional here
19-
/// The specific type of the argument builder to be used for remote calls.
20-
associatedtype Invocation: DistributedTargetInvocation
19+
20+
associatedtype InvocationEncoder: DistributedTargetInvocationEncoder
21+
associatedtype InvocationDecoder: DistributedTargetInvocationDecoder
2122

2223
/// The serialization requirement that will be applied to all distributed targets used with this system.
23-
typealias SerializationRequirement = Invocation.SerializationRequirement
24+
associatedtype SerializationRequirement
25+
where SerializationRequirement == InvocationEncoder.SerializationRequirement,
26+
SerializationRequirement == InvocationDecoder.SerializationRequirement
2427

2528
// ==== ---------------------------------------------------------------------
2629
// - MARK: Resolving actors by identity
@@ -104,12 +107,12 @@ public protocol DistributedActorSystem: Sendable {
104107
/// arguments, generic substitutions, and specific error and return types
105108
/// that are associated with this specific invocation.
106109
@inlinable
107-
func makeInvocation() throws -> Invocation
110+
func makeInvocationEncoder() throws -> InvocationEncoder
108111

109112
/// Invoked by the Swift runtime when making a remote call.
110113
///
111114
/// The `arguments` are the arguments container that was previously created
112-
/// by `makeInvocation` and has been populated with all arguments.
115+
/// by `makeInvocationEncoder` and has been populated with all arguments.
113116
///
114117
/// This method should perform the actual remote function call, and await for its response.
115118
///
@@ -119,7 +122,7 @@ public protocol DistributedActorSystem: Sendable {
119122
// func remoteCall<Act, Err, Res>(
120123
// on actor: Act,
121124
// target: RemoteCallTarget,
122-
// invocation: inout Invocation,
125+
// invocation: inout InvocationDecoder,
123126
// throwing: Err.Type,
124127
// returning: Res.Type
125128
// ) async throws -> Res
@@ -149,7 +152,7 @@ extension DistributedActorSystem {
149152
public func executeDistributedTarget<Act, ResultHandler>(
150153
on actor: Act,
151154
mangledTargetName: String,
152-
invocation: inout Invocation,
155+
invocationDecoder: inout InvocationDecoder,
153156
handler: ResultHandler
154157
) async throws where Act: DistributedActor,
155158
Act.ID == ActorID,
@@ -245,41 +248,43 @@ extension DistributedActorSystem {
245248
do {
246249
// Decode the invocation and pack arguments into the h-buffer
247250
// TODO(distributed): decode the generics info
248-
var argumentDecoder = invocation.makeArgumentDecoder()
249-
var paramIdx = 0
250-
for unsafeRawArgPointer in hargs {
251-
guard paramIdx < paramCount else {
252-
throw ExecuteDistributedTargetError(
253-
message: "Unexpected attempt to decode more parameters than expected: \(paramIdx + 1)")
251+
// TODO(distributed): move this into the IRGen synthesized funcs, so we don't need hargs at all and can specialize the decodeNextArgument calls
252+
do {
253+
var paramIdx = 0
254+
for unsafeRawArgPointer in hargs {
255+
guard paramIdx < paramCount else {
256+
throw ExecuteDistributedTargetError(
257+
message: "Unexpected attempt to decode more parameters than expected: \(paramIdx + 1)")
258+
}
259+
let paramType = paramTypes[paramIdx]
260+
paramIdx += 1
261+
262+
// FIXME(distributed): func doDecode<Arg: SerializationRequirement>(_: Arg.Type) throws {
263+
// FIXME: but how would we call this...?
264+
// FIXME: > type 'Arg' constrained to non-protocol, non-class type 'Self.Invocation.SerializationRequirement'
265+
func doDecodeArgument<Arg>(_: Arg.Type) throws {
266+
let unsafeArgPointer = unsafeRawArgPointer
267+
.bindMemory(to: Arg.self, capacity: 1)
268+
try invocationDecoder.decodeNextArgument(Arg.self, into: unsafeArgPointer)
269+
}
270+
try _openExistential(paramType, do: doDecodeArgument)
254271
}
255-
let paramType = paramTypes[paramIdx]
256-
paramIdx += 1
257-
258-
// FIXME(distributed): func doDecode<Arg: SerializationRequirement>(_: Arg.Type) throws {
259-
// FIXME: but how would we call this...?
260-
// FIXME: > type 'Arg' constrained to non-protocol, non-class type 'Self.Invocation.SerializationRequirement'
261-
func doDecodeArgument<Arg>(_: Arg.Type) throws {
262-
let unsafeArgPointer = unsafeRawArgPointer
263-
.bindMemory(to: Arg.self, capacity: 1)
264-
try argumentDecoder.decodeNext(Arg.self, into: unsafeArgPointer)
265-
}
266-
try _openExistential(paramType, do: doDecodeArgument)
267272
}
268273

269-
let returnType = try invocation.decodeReturnType() ?? returnTypeFromTypeInfo
274+
let returnType = try invocationDecoder.decodeReturnType() ?? returnTypeFromTypeInfo
270275
// let errorType = try invocation.decodeErrorType() // TODO: decide how to use?
276+
271277
// Execute the target!
272278
try await _executeDistributedTarget(
273279
on: actor,
274280
mangledTargetName, UInt(mangledTargetName.count),
275-
argumentBuffer: hargs.buffer._rawValue,
281+
argumentBuffer: hargs.buffer._rawValue, // TODO(distributed): pass the invocationDecoder instead, so we can decode inside IRGen directly into the argument explosion
276282
resultBuffer: resultBuffer._rawValue
277283
)
278284

279285
func onReturn<R>(_ resultTy: R.Type) async throws {
280286
try await handler.onReturn/*<R>*/(value: resultBuffer.load(as: resultTy))
281287
}
282-
283288
try await _openExistential(returnType, do: onReturn)
284289
} catch {
285290
try await handler.onThrow(error: error)
@@ -314,7 +319,7 @@ public struct RemoteCallTarget {
314319
}
315320
}
316321

317-
/// Represents an invocation of a distributed target (method or computed property).
322+
/// Used to encode an invocation of a distributed target (method or computed property).
318323
///
319324
/// ## Forming an invocation
320325
///
@@ -345,46 +350,35 @@ public struct RemoteCallTarget {
345350
/// Note that the decoding will be provided the specific types that the sending side used to preform the call,
346351
/// so decoding can rely on simply invoking e.g. `Codable` (if that is the `SerializationRequirement`) decoding
347352
/// entry points on the provided types.
348-
@available(SwiftStdlib 5.6, *)
349-
public protocol DistributedTargetInvocation {
350-
associatedtype ArgumentDecoder: DistributedTargetInvocationArgumentDecoder
353+
public protocol DistributedTargetInvocationEncoder {
351354
associatedtype SerializationRequirement
352355

353-
// === Sending / recording -------------------------------------------------
354356
/// The arguments must be encoded order-preserving, and once `decodeGenericSubstitutions`
355357
/// is called, the substitutions must be returned in the same order in which they were recorded.
356358
mutating func recordGenericSubstitution<T>(_ type: T.Type) throws
357359

358360
// /// Ad-hoc requirement
359361
// ///
360362
// /// Record an argument of `Argument` type in this arguments storage.
361-
// mutating func recordArgument<Argument: SerializationRequirement>(argument: Argument) throws
363+
// mutating func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws
364+
362365
mutating func recordErrorType<E: Error>(_ type: E.Type) throws
363366

364367
// /// Ad-hoc requirement
365368
// ///
366369
// /// Record the return type of the distributed method.
367370
// mutating func recordReturnType<R: SerializationRequirement>(_ type: R.Type) throws
368-
mutating func doneRecording() throws
369371

370-
// === Receiving / decoding -------------------------------------------------
371-
mutating func decodeGenericSubstitutions() throws -> [Any.Type]
372-
373-
func makeArgumentDecoder() -> Self.ArgumentDecoder
374-
375-
mutating func decodeReturnType() throws -> Any.Type?
376-
377-
mutating func decodeErrorType() throws -> Any.Type?
372+
mutating func doneRecording() throws
378373
}
379374

380-
/// Decoding iterator produced by `DistributedTargetInvocation.argumentDecoder()`.
381-
///
382-
/// It will be called exactly `N` times where `N` is the known number of arguments
383-
/// to the target invocation.
384-
@available(SwiftStdlib 5.6, *)
385-
public protocol DistributedTargetInvocationArgumentDecoder {
375+
/// Decoder that must be provided to `executeDistributedTarget` and is used
376+
/// by the Swift runtime to decode arguments of the invocation.
377+
public protocol DistributedTargetInvocationDecoder {
386378
associatedtype SerializationRequirement
387379

380+
func decodeGenericSubstitutions() throws -> [Any.Type]
381+
388382
// /// Ad-hoc protocol requirement
389383
// ///
390384
// /// Attempt to decode the next argument from the underlying buffers into pre-allocated storage
@@ -398,14 +392,27 @@ public protocol DistributedTargetInvocationArgumentDecoder {
398392
// /// buffer for all the arguments and their expected types. The 'pointer' passed here is a pointer
399393
// /// to a "slot" in that pre-allocated buffer. That buffer will then be passed to a thunk that
400394
// /// performs the actual distributed (local) instance method invocation.
401-
// mutating func decodeNext<Argument: SerializationRequirement>(
395+
// mutating func decodeNextArgument<Argument: SerializationRequirement>(
402396
// into pointer: UnsafeMutablePointer<Argument> // pointer to our hbuffer
403397
// ) throws
404398
// FIXME(distributed): remove this since it must have the ': SerializationRequirement'
405-
mutating func decodeNext<Argument>(
399+
mutating func decodeNextArgument<Argument>(
406400
_ argumentType: Argument.Type,
407401
into pointer: UnsafeMutablePointer<Argument> // pointer to our hbuffer
408402
) throws
403+
404+
func decodeErrorType() throws -> Any.Type?
405+
406+
func decodeReturnType() throws -> Any.Type?
407+
}
408+
409+
///
410+
/// It will be called exactly `N` times where `N` is the known number of arguments
411+
/// to the target invocation.
412+
@available(SwiftStdlib 5.6, *)
413+
public protocol DistributedTargetInvocationArgumentDecoder {
414+
415+
409416
}
410417

411418
@available(SwiftStdlib 5.6, *)
@@ -427,8 +434,9 @@ public protocol DistributedActorSystemError: Error {}
427434

428435
@available(SwiftStdlib 5.6, *)
429436
public struct ExecuteDistributedTargetError: DistributedActorSystemError {
430-
private let message: String
431-
internal init(message: String) {
437+
let message: String
438+
439+
public init(message: String) {
432440
self.message = message
433441
}
434442
}

test/Distributed/Inputs/FakeDistributedActorSystems.swift

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ public struct ActorAddress: Hashable, Sendable, Codable {
3737

3838
public struct FakeActorSystem: DistributedActorSystem {
3939
public typealias ActorID = ActorAddress
40-
public typealias Invocation = FakeInvocation
40+
public typealias InvocationDecoder = FakeInvocation
41+
public typealias InvocationEncoder = FakeInvocation
4142
public typealias SerializationRequirement = Codable
4243

4344
// just so that the struct does not get optimized away entirely
@@ -70,27 +71,29 @@ public struct FakeActorSystem: DistributedActorSystem {
7071
public func resignID(_ id: ActorID) {
7172
}
7273

73-
public func makeInvocation() -> Invocation {
74+
public func makeInvocationEncoder() -> InvocationDecoder {
7475
.init()
7576
}
7677
}
7778

78-
public struct FakeInvocation: DistributedTargetInvocation {
79-
public typealias ArgumentDecoder = FakeArgumentDecoder
79+
public struct FakeInvocation: DistributedTargetInvocationEncoder, DistributedTargetInvocationDecoder {
8080
public typealias SerializationRequirement = Codable
8181

82-
public mutating func recordGenericSubstitution<T>(mangledType: T.Type) throws {}
83-
public mutating func recordArgument<Argument: SerializationRequirement>(argument: Argument) throws {}
84-
public mutating func recordReturnType<R: SerializationRequirement>(mangledType: R.Type) throws {}
85-
public mutating func recordErrorType<E: Error>(mangledType: E.Type) throws {}
82+
public mutating func recordGenericSubstitution<T>(_ type: T.Type) throws {}
83+
public mutating func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws {}
84+
public mutating func recordReturnType<R: SerializationRequirement>(_ type: R.Type) throws {}
85+
public mutating func recordErrorType<E: Error>(_ type: E.Type) throws {}
8686
public mutating func doneRecording() throws {}
8787

8888
// === Receiving / decoding -------------------------------------------------
8989

90-
public mutating func decodeGenericSubstitutions() throws -> [Any.Type] { [] }
91-
public mutating func argumentDecoder() -> FakeArgumentDecoder { .init() }
92-
public mutating func decodeReturnType() throws -> Any.Type? { nil }
93-
public mutating func decodeErrorType() throws -> Any.Type? { nil }
90+
public func decodeGenericSubstitutions() throws -> [Any.Type] { [] }
91+
public mutating func decodeNextArgument<Argument>(
92+
_ argumentType: Argument.Type,
93+
into pointer: UnsafeMutablePointer<Argument> // pointer to our hbuffer
94+
) throws { /* ... */ }
95+
public func decodeReturnType() throws -> Any.Type? { nil }
96+
public func decodeErrorType() throws -> Any.Type? { nil }
9497

9598
public struct FakeArgumentDecoder: DistributedTargetInvocationArgumentDecoder {
9699
public typealias SerializationRequirement = Codable

test/Distributed/Inputs/dynamic_replacement_da_decl.swift

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ struct ActorAddress: Hashable, Sendable, Codable {
3636

3737
final class FakeActorSystem: DistributedActorSystem {
3838
typealias ActorID = ActorAddress
39-
typealias Invocation = FakeInvocation
39+
typealias InvocationDecoder = FakeInvocation
40+
typealias InvocationEncoder = FakeInvocation
4041
typealias SerializationRequirement = Codable
4142

4243
func resolve<Act>(id: ActorID, as actorType: Act.Type) throws -> Act?
@@ -59,27 +60,29 @@ final class FakeActorSystem: DistributedActorSystem {
5960
func resignID(_ id: ActorID) {
6061
}
6162

62-
func makeInvocation() -> Invocation {
63+
func makeInvocationEncoder() -> InvocationDecoder {
6364
.init()
6465
}
6566
}
6667

67-
struct FakeInvocation: DistributedTargetInvocation {
68-
typealias ArgumentDecoder = FakeArgumentDecoder
68+
struct FakeInvocation: DistributedTargetInvocationEncoder, DistributedTargetInvocationDecoder {
6969
typealias SerializationRequirement = Codable
7070

71-
mutating func recordGenericSubstitution<T>(mangledType: T.Type) throws {}
72-
mutating func recordArgument<Argument: SerializationRequirement>(argument: Argument) throws {}
73-
mutating func recordReturnType<R: SerializationRequirement>(mangledType: R.Type) throws {}
74-
mutating func recordErrorType<E: Error>(mangledType: E.Type) throws {}
71+
mutating func recordGenericSubstitution<T>(_ type: T.Type) throws {}
72+
mutating func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws {}
73+
mutating func recordReturnType<R: SerializationRequirement>(_ type: R.Type) throws {}
74+
mutating func recordErrorType<E: Error>(_ type: E.Type) throws {}
7575
mutating func doneRecording() throws {}
7676

7777
// === Receiving / decoding -------------------------------------------------
7878

79-
mutating func decodeGenericSubstitutions() throws -> [Any.Type] { [] }
80-
mutating func argumentDecoder() -> FakeArgumentDecoder { .init() }
81-
mutating func decodeReturnType() throws -> Any.Type? { nil }
82-
mutating func decodeErrorType() throws -> Any.Type? { nil }
79+
func decodeGenericSubstitutions() throws -> [Any.Type] { [] }
80+
mutating func decodeNextArgument<Argument>(
81+
_ argumentType: Argument.Type,
82+
into pointer: UnsafeMutablePointer<Argument> // pointer to our hbuffer
83+
) throws { /* ... */ }
84+
func decodeReturnType() throws -> Any.Type? { nil }
85+
func decodeErrorType() throws -> Any.Type? { nil }
8386

8487
struct FakeArgumentDecoder: DistributedTargetInvocationArgumentDecoder {
8588
typealias SerializationRequirement = Codable

0 commit comments

Comments
 (0)