Skip to content

[6.0][Distributed] Avoid infinite recursion in distributed thunk on protocol extension #73033

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 71 additions & 22 deletions lib/SILGen/SILGenApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,69 @@ class Callee {
return result;
}


/// Determine if the target `func` should be replaced with a
/// 'distributed thunk'.
///
/// This only applies to distributed functions when calls are made cross-actor
/// isolation. One notable exception is a distributed thunk calling the "real
/// underlying method", in which case (to avoid the thunk calling into itself,
/// the real method must be called).
///
/// Witness calls which may need to be replaced with a distributed thunk call
/// happen either when the target type is generic, or if we are inside an
/// extension on a protocol. This method checks if we are in a context
/// where we should be calling the distributed thunk of the `func` or not.
/// Notably, if we are inside a distributed thunk already and are trying to
/// apply distributed method calls, all those must be to the "real" method,
/// because the thunks' responsibility is to call the real method, so this
/// replacement cannot be applied (or we'd recursively keep calling the same
/// thunk via witness).
///
/// In situations which do not use a witness call, distributed methods are always
/// invoked Direct, and never ClassMethod, because distributed are effectively
/// final.
///
/// \param constant the target that we want to dispatch to
/// \return true when the function should be considered for replacement
/// with distributed thunk when applying it
bool shouldDispatchWitnessViaDistributedThunk(
SILGenFunction &SGF,
std::optional<SILDeclRef> constant
) const {
if (!constant.has_value())
return false;

auto func = dyn_cast<FuncDecl>(constant->getDecl());
if (!func)
return false;

auto isDistributedFuncOrAccessor =
func->isDistributed();
if (auto acc = dyn_cast<AccessorDecl>(func)) {
isDistributedFuncOrAccessor =
acc->getStorage()->isDistributed();
}

if (!isDistributedFuncOrAccessor)
return false;

// If we are inside a distributed thunk, we want to call the "real" method,
// in order to avoid infinitely recursively calling the thunk from itself.
if (SGF.F.isDistributed() && SGF.F.isThunk())
return false;

// If caller and called func are isolated to the same (distributed) actor,
// (i.e. we are "inside the distributed actor"), there is no need to call
// the thunk.
if (isSameActorIsolated(func, SGF.FunctionDC))
return false;

// In all other situations, we may have to replace the called function,
// depending on isolation (to be checked in SILGenApply).
return true;
}

ManagedValue getFnValue(SILGenFunction &SGF,
std::optional<ManagedValue> borrowedSelf) const & {
std::optional<SILDeclRef> constant = std::nullopt;
Expand Down Expand Up @@ -681,22 +744,12 @@ class Callee {
return fn;
}
case Kind::WitnessMethod: {
if (auto func = constant->getFuncDecl()) {
auto isDistributedFuncOrAccessor =
func->isDistributed();
if (auto acc = dyn_cast<AccessorDecl>(func)) {
isDistributedFuncOrAccessor =
acc->getStorage()->isDistributed();
}
if (isa<ProtocolDecl>(func->getDeclContext()) && isDistributedFuncOrAccessor) {
// If we're calling cross-actor, we must always use a distributed thunk
if (!isSameActorIsolated(func, SGF.FunctionDC)) {
// the protocol witness must always be a distributed thunk, as we
// may be crossing a remote boundary here.
auto thunk = func->getDistributedThunk();
constant = SILDeclRef(thunk).asDistributed();
}
}
if (shouldDispatchWitnessViaDistributedThunk(SGF, constant)) {
auto func = dyn_cast<FuncDecl>(constant->getDecl());
assert(func); // guaranteed be non-null if shouldDispatch returned true

auto thunk = func->getDistributedThunk();
constant = SILDeclRef(thunk).asDistributed();
}

auto constantInfo =
Expand Down Expand Up @@ -783,12 +836,8 @@ class Callee {
}
case Kind::WitnessMethod: {
if (auto func = constant->getFuncDecl()) {
if (func->isDistributed() && isa<ProtocolDecl>(func->getDeclContext())) {
// If we're calling cross-actor, we must always use a distributed thunk
if (!isSameActorIsolated(func, SGF.FunctionDC)) {
/// We must adjust the constant to use a distributed thunk.
constant = constant->asDistributed();
}
if (shouldDispatchWitnessViaDistributedThunk(SGF, constant)) {
constant = constant->asDistributed();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,28 @@ import FakeDistributedActorSystems

typealias DefaultDistributedActorSystem = FakeRoundtripActorSystem

extension FakeRoundtripActorSystemDistributedActor {
protocol KappaProtocol : DistributedActor where ActorSystem == FakeRoundtripActorSystem {
distributed func echo(name: String) -> String
}

extension KappaProtocol {
distributed func echo(name: String) -> String {
return "Echo: \(name)"
}
}

distributed actor KappaProtocolImpl: KappaProtocol {
// empty, gets default impl from extension on protocol
}

func test() async throws {
let system = DefaultDistributedActorSystem()

let local = FakeRoundtripActorSystemDistributedActor(actorSystem: system)
let ref = try FakeRoundtripActorSystemDistributedActor.resolve(id: local.id, using: system)
let local = KappaProtocolImpl(actorSystem: system)
let ref = try KappaProtocolImpl.resolve(id: local.id, using: system)

let reply = try await ref.echo(name: "Caplin")
// CHECK: >> remoteCall: on:main.FakeRoundtripActorSystemDistributedActor, target:main.FakeRoundtripActorSystemDistributedActor.echo(name:), invocation:FakeInvocationEncoder(genericSubs: [], arguments: ["Caplin"], returnType: Optional(Swift.String), errorType: nil), throwing:Swift.Never, returning:Swift.String
// CHECK: >> remoteCall: on:main.KappaProtocolImpl, target:main.$KappaProtocol.echo(name:), invocation:FakeInvocationEncoder(genericSubs: [main.KappaProtocolImpl], arguments: ["Caplin"], returnType: Optional(Swift.String), errorType: nil), throwing:Swift.Never, returning:Swift.String

// CHECK: << remoteCall return: Echo: Caplin
print("reply: \(reply)")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// RUN: %empty-directory(%t)
// RUN: %target-swift-frontend-emit-module -emit-module-path %t/FakeDistributedActorSystems.swiftmodule -module-name FakeDistributedActorSystems -disable-availability-checking %S/../Inputs/FakeDistributedActorSystems.swift
// RUN: %target-build-swift -module-name main -Xfrontend -disable-availability-checking -j2 -parse-as-library -I %t %s %S/../Inputs/FakeDistributedActorSystems.swift -o %t/a.out
// RUN: %target-codesign %t/a.out
// RUN: %target-run %t/a.out | %FileCheck %s --dump-input=always

// REQUIRES: executable_test
// REQUIRES: concurrency
// REQUIRES: distributed

// rdar://76038845
// UNSUPPORTED: use_os_stdlib
// UNSUPPORTED: back_deployment_runtime

// UNSUPPORTED: OS=windows-msvc

import Distributed
import FakeDistributedActorSystems

typealias DefaultDistributedActorSystem = FakeRoundtripActorSystem

protocol KappaProtocol : DistributedActor where ActorSystem == FakeRoundtripActorSystem {
distributed func echo(name: String) -> String
}

distributed actor KappaProtocolImpl: KappaProtocol {
// empty, gets default impl from extension on this actor
}

extension KappaProtocolImpl {
distributed func echo(name: String) -> String {
return "Echo: \(name)"
}
}

func test() async throws {
let system = DefaultDistributedActorSystem()

let local = KappaProtocolImpl(actorSystem: system)
let ref = try KappaProtocolImpl.resolve(id: local.id, using: system)

let reply = try await ref.echo(name: "Caplin")
// CHECK: >> remoteCall: on:main.KappaProtocolImpl, target:main.KappaProtocolImpl.echo(name:), invocation:FakeInvocationEncoder(genericSubs: [], arguments: ["Caplin"], returnType: Optional(Swift.String), errorType: nil), throwing:Swift.Never, returning:Swift.String

// CHECK: << remoteCall return: Echo: Caplin
print("reply: \(reply)")
// CHECK: reply: Echo: Caplin
}

@main struct Main {
static func main() async {
try! await test()
}
}