Skip to content

Commit 460a030

Browse files
authored
[Distributed] Avoid infinite recursion in distributed thunk on protocol extensions (#73032)
1 parent 5609657 commit 460a030

File tree

5 files changed

+127
-24
lines changed

5 files changed

+127
-24
lines changed

lib/SILGen/SILGenApply.cpp

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -682,20 +682,9 @@ class Callee {
682682
}
683683
case Kind::WitnessMethod: {
684684
if (auto func = constant->getFuncDecl()) {
685-
auto isDistributedFuncOrAccessor =
686-
func->isDistributed();
687-
if (auto acc = dyn_cast<AccessorDecl>(func)) {
688-
isDistributedFuncOrAccessor =
689-
acc->getStorage()->isDistributed();
690-
}
691-
if (isa<ProtocolDecl>(func->getDeclContext()) && isDistributedFuncOrAccessor) {
692-
// If we're calling cross-actor, we must always use a distributed thunk
693-
if (!isSameActorIsolated(func, SGF.FunctionDC)) {
694-
// the protocol witness must always be a distributed thunk, as we
695-
// may be crossing a remote boundary here.
696-
auto thunk = func->getDistributedThunk();
697-
constant = SILDeclRef(thunk).asDistributed();
698-
}
685+
if (SGF.shouldReplaceConstantForApplyWithDistributedThunk(func)) {
686+
auto thunk = func->getDistributedThunk();
687+
constant = SILDeclRef(thunk).asDistributed();
699688
}
700689
}
701690

@@ -783,12 +772,8 @@ class Callee {
783772
}
784773
case Kind::WitnessMethod: {
785774
if (auto func = constant->getFuncDecl()) {
786-
if (func->isDistributed() && isa<ProtocolDecl>(func->getDeclContext())) {
787-
// If we're calling cross-actor, we must always use a distributed thunk
788-
if (!isSameActorIsolated(func, SGF.FunctionDC)) {
789-
/// We must adjust the constant to use a distributed thunk.
790-
constant = constant->asDistributed();
791-
}
775+
if (SGF.shouldReplaceConstantForApplyWithDistributedThunk(func)) {
776+
constant = constant->asDistributed();
792777
}
793778
}
794779

lib/SILGen/SILGenDistributed.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,34 @@ void InitializeDistActorIdentity::dump(SILGenFunction &) const {
274274
#endif
275275
}
276276

277+
bool SILGenFunction::shouldReplaceConstantForApplyWithDistributedThunk(
278+
FuncDecl *func) const {
279+
auto isDistributedFuncOrAccessor =
280+
func->isDistributed();
281+
if (auto acc = dyn_cast<AccessorDecl>(func)) {
282+
isDistributedFuncOrAccessor =
283+
acc->getStorage()->isDistributed();
284+
}
285+
286+
if (!isDistributedFuncOrAccessor)
287+
return false;
288+
289+
// If we are inside a distributed thunk, we want to call the "real" method,
290+
// in order to avoid infinitely recursively calling the thunk from itself.
291+
if (F.isDistributed() && F.isThunk())
292+
return false;
293+
294+
// If caller and called func are isolated to the same (distributed) actor,
295+
// (i.e. we are "inside the distributed actor"), there is no need to call
296+
// the thunk.
297+
if (isSameActorIsolated(func, FunctionDC))
298+
return false;
299+
300+
// In all other situations, we may have to replace the called function,
301+
// depending on isolation (to be checked in SILGenApply).
302+
return true;
303+
}
304+
277305
void SILGenFunction::emitDistributedActorImplicitPropertyInits(
278306
ConstructorDecl *ctor, ManagedValue selfArg) {
279307
// Only designated initializers should perform this initialization.

lib/SILGen/SILGenFunction.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2452,6 +2452,34 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
24522452
// Distributed Actors
24532453
//===---------------------------------------------------------------------===//
24542454

2455+
/// Determine if the target `func` should be replaced with a
2456+
/// 'distributed thunk'.
2457+
///
2458+
/// This only applies to distributed functions when calls are made cross-actor
2459+
/// isolation. One notable exception is a distributed thunk calling the "real
2460+
/// underlying method", in which case (to avoid the thunk calling into itself,
2461+
/// the real method must be called).
2462+
///
2463+
/// Witness calls which may need to be replaced with a distributed thunk call
2464+
/// happen either when the target type is generic, or if we are inside an
2465+
/// extension on a protocol. This method checks if we are in a context
2466+
/// where we should be calling the distributed thunk of the `func` or not.
2467+
/// Notably, if we are inside a distributed thunk already and are trying to
2468+
/// apply distributed method calls, all those must be to the "real" method,
2469+
/// because the thunks' responsibility is to call the real method, so this
2470+
/// replacement cannot be applied (or we'd recursively keep calling the same
2471+
/// thunk via witness).
2472+
///
2473+
/// In situations which do not use a witness call, distributed methods are always
2474+
/// invoked Direct, and never ClassMethod, because distributed are effectively
2475+
/// final.
2476+
///
2477+
/// \param func the target func that we are trying to "apply"
2478+
/// \return true when the function should be considered for replacement
2479+
/// with distributed thunk when applying it
2480+
bool
2481+
shouldReplaceConstantForApplyWithDistributedThunk(FuncDecl *func) const;
2482+
24552483
/// Initializes the implicit stored properties of a distributed actor that correspond to
24562484
/// its transport and identity.
24572485
void emitDistributedActorImplicitPropertyInits(

test/Distributed/Runtime/distributed_actor_func_calls_remoteCall_extension.swift

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,28 @@ import FakeDistributedActorSystems
1919

2020
typealias DefaultDistributedActorSystem = FakeRoundtripActorSystem
2121

22-
extension FakeRoundtripActorSystemDistributedActor {
22+
protocol KappaProtocol : DistributedActor where ActorSystem == FakeRoundtripActorSystem {
23+
distributed func echo(name: String) -> String
24+
}
25+
26+
extension KappaProtocol {
2327
distributed func echo(name: String) -> String {
2428
return "Echo: \(name)"
2529
}
2630
}
2731

32+
distributed actor KappaProtocolImpl: KappaProtocol {
33+
// empty, gets default impl from extension on protocol
34+
}
35+
2836
func test() async throws {
2937
let system = DefaultDistributedActorSystem()
3038

31-
let local = FakeRoundtripActorSystemDistributedActor(actorSystem: system)
32-
let ref = try FakeRoundtripActorSystemDistributedActor.resolve(id: local.id, using: system)
39+
let local = KappaProtocolImpl(actorSystem: system)
40+
let ref = try KappaProtocolImpl.resolve(id: local.id, using: system)
3341

3442
let reply = try await ref.echo(name: "Caplin")
35-
// CHECK: >> remoteCall: on:main.FakeRoundtripActorSystemDistributedActor, target:main.FakeRoundtripActorSystemDistributedActor.echo(name:), invocation:FakeInvocationEncoder(genericSubs: [], arguments: ["Caplin"], returnType: Optional(Swift.String), errorType: nil), throwing:Swift.Never, returning:Swift.String
43+
// CHECK: >> remoteCall: on:main.KappaProtocolImpl, target:main.$KappaProtocol.echo(name:), invocation:FakeInvocationEncoder(genericSubs: [main.KappaProtocolImpl], arguments: ["Caplin"], returnType: Optional(Swift.String), errorType: nil), throwing:Swift.Never, returning:Swift.String
3644

3745
// CHECK: << remoteCall return: Echo: Caplin
3846
print("reply: \(reply)")
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-swift-frontend-emit-module -emit-module-path %t/FakeDistributedActorSystems.swiftmodule -module-name FakeDistributedActorSystems -disable-availability-checking %S/../Inputs/FakeDistributedActorSystems.swift
3+
// RUN: %target-build-swift -module-name main -Xfrontend -disable-availability-checking -j2 -parse-as-library -I %t %s %S/../Inputs/FakeDistributedActorSystems.swift -o %t/a.out
4+
// RUN: %target-codesign %t/a.out
5+
// RUN: %target-run %t/a.out | %FileCheck %s --dump-input=always
6+
7+
// REQUIRES: executable_test
8+
// REQUIRES: concurrency
9+
// REQUIRES: distributed
10+
11+
// rdar://76038845
12+
// UNSUPPORTED: use_os_stdlib
13+
// UNSUPPORTED: back_deployment_runtime
14+
15+
// UNSUPPORTED: OS=windows-msvc
16+
17+
import Distributed
18+
import FakeDistributedActorSystems
19+
20+
typealias DefaultDistributedActorSystem = FakeRoundtripActorSystem
21+
22+
protocol KappaProtocol : DistributedActor where ActorSystem == FakeRoundtripActorSystem {
23+
distributed func echo(name: String) -> String
24+
}
25+
26+
distributed actor KappaProtocolImpl: KappaProtocol {
27+
// empty, gets default impl from extension on this actor
28+
}
29+
30+
extension KappaProtocolImpl {
31+
distributed func echo(name: String) -> String {
32+
return "Echo: \(name)"
33+
}
34+
}
35+
36+
func test() async throws {
37+
let system = DefaultDistributedActorSystem()
38+
39+
let local = KappaProtocolImpl(actorSystem: system)
40+
let ref = try KappaProtocolImpl.resolve(id: local.id, using: system)
41+
42+
let reply = try await ref.echo(name: "Caplin")
43+
// CHECK: >> remoteCall: on:main.KappaProtocolImpl, target:main.KappaProtocolImpl.echo(name:), invocation:FakeInvocationEncoder(genericSubs: [], arguments: ["Caplin"], returnType: Optional(Swift.String), errorType: nil), throwing:Swift.Never, returning:Swift.String
44+
45+
// CHECK: << remoteCall return: Echo: Caplin
46+
print("reply: \(reply)")
47+
// CHECK: reply: Echo: Caplin
48+
}
49+
50+
@main struct Main {
51+
static func main() async {
52+
try! await test()
53+
}
54+
}

0 commit comments

Comments
 (0)