Skip to content

Commit bf0681d

Browse files
committed
distributed thunk flag
1 parent ebdb061 commit bf0681d

File tree

4 files changed

+75
-31
lines changed

4 files changed

+75
-31
lines changed

include/swift/AST/Expr.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,12 +307,13 @@ class alignas(8) Expr : public ASTAllocated<Expr> {
307307
NumCaptures : 32
308308
);
309309

310-
SWIFT_INLINE_BITFIELD(ApplyExpr, Expr, 1+1+1+1+1,
310+
SWIFT_INLINE_BITFIELD(ApplyExpr, Expr, 1+1+1+1+1+1,
311311
ThrowsIsSet : 1,
312312
Throws : 1,
313313
ImplicitlyAsync : 1,
314314
ImplicitlyThrows : 1,
315-
NoAsync : 1
315+
NoAsync : 1,
316+
ShouldApplyDistributedThunk : 1
316317
);
317318

318319
SWIFT_INLINE_BITFIELD_EMPTY(CallExpr, ApplyExpr);
@@ -4423,7 +4424,12 @@ class ApplyExpr : public Expr {
44234424

44244425
/// Informs IRGen to that this expression should be applied as its distributed
44254426
/// thunk, rather than invoking the function directly.
4426-
bool shouldApplyDistributedThunk() const;
4427+
bool shouldApplyDistributedThunk() const {
4428+
return Bits.ApplyExpr.ShouldApplyDistributedThunk;
4429+
}
4430+
void setShouldApplyDistributedThunk(bool flag) {
4431+
Bits.ApplyExpr.ShouldApplyDistributedThunk = flag;
4432+
}
44274433

44284434
ValueDecl *getCalledValue() const;
44294435

lib/AST/Expr.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,22 +1068,6 @@ Type OverloadSetRefExpr::getBaseType() const {
10681068
llvm_unreachable("Unhandled overloaded set reference expression");
10691069
}
10701070

1071-
bool ApplyExpr::shouldApplyDistributedThunk() const {
1072-
// only a distributed decl has a chance of needing to be invoked as a thunk
1073-
auto func = dyn_cast<AbstractFunctionDecl>(getCalledValue());
1074-
if (!func || !func->isDistributed())
1075-
return false;
1076-
1077-
if (implicitlyThrows())
1078-
return true;
1079-
1080-
auto isEffectivelyAsync = func->hasAsync() || isImplicitlyAsync();
1081-
if (func->hasThrows() && isEffectivelyAsync)
1082-
return true;
1083-
1084-
return false;
1085-
}
1086-
10871071
bool OverloadSetRefExpr::hasBaseObject() const {
10881072
if (Type BaseTy = getBaseType())
10891073
return !BaseTy->is<AnyMetatypeType>();

lib/Sema/TypeCheckConcurrency.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,14 +1003,16 @@ namespace {
10031003
///
10041004
/// and we reach up to mark the CallExpr.
10051005
void markNearestCallAsImplicitly(
1006-
Optional<ImplicitActorHopTarget> setAsync, bool setThrows = false) {
1006+
Optional<ImplicitActorHopTarget> setAsync,
1007+
bool setThrows = false, bool setDistributedThunk = false) {
10071008
assert(applyStack.size() > 0 && "not contained within an Apply?");
10081009

10091010
const auto End = applyStack.rend();
10101011
for (auto I = applyStack.rbegin(); I != End; ++I)
10111012
if (auto call = dyn_cast<CallExpr>(*I)) {
10121013
if (setAsync) call->setImplicitlyAsync(*setAsync);
10131014
if (setThrows) call->setImplicitlyThrows(true);
1015+
if (setDistributedThunk) call->setShouldApplyDistributedThunk(true);
10141016
return;
10151017
}
10161018
llvm_unreachable("expected a CallExpr in applyStack!");
@@ -1667,7 +1669,6 @@ namespace {
16671669
ThrowsMarkingResult tryMarkImplicitlyThrows(SourceLoc declLoc,
16681670
ConcreteDeclRef concDeclRef,
16691671
Expr* context) {
1670-
16711672
ValueDecl *decl = concDeclRef.getDecl();
16721673
ThrowsMarkingResult result = ThrowsMarkingResult::NotFound;
16731674

@@ -1869,8 +1870,11 @@ namespace {
18691870
}
18701871

18711872
switch (contextIsolation) {
1872-
case ActorIsolation::ActorInstance:
1873-
case ActorIsolation::DistributedActorInstance: {
1873+
case ActorIsolation::DistributedActorInstance:
1874+
markNearestCallAsImplicitly(/*setAsync*/None, /*setThrows*/false,
1875+
/*setDistributedThunk*/true);
1876+
LLVM_FALLTHROUGH;
1877+
case ActorIsolation::ActorInstance: {
18741878
auto result = tryMarkImplicitlyAsync(
18751879
loc, valueRef, context,
18761880
ImplicitActorHopTarget::forGlobalActor(globalActor));
@@ -2278,6 +2282,8 @@ namespace {
22782282
tryMarkImplicitlyAsync(memberLoc, memberRef, context,
22792283
ImplicitActorHopTarget::forInstanceSelf());
22802284
tryMarkImplicitlyThrows(memberLoc, memberRef, context);
2285+
markNearestCallAsImplicitly(/*setAsync*/None, /*setThrows*/false,
2286+
/*setDistributedThunk*/true);
22812287

22822288
} else {
22832289
// neither static or distributed, apply full distributed isolation
@@ -2363,8 +2369,7 @@ namespace {
23632369
}
23642370

23652371
// It wasn't a distributed func, so ban the access
2366-
// TODO(distributed): handle subscripts here too
2367-
if (auto var = dyn_cast<VarDecl>(member)) {
2372+
if (isPropOrSubscript(member)) {
23682373
ctx.Diags.diagnose(
23692374
memberLoc, diag::distributed_actor_isolated_non_self_reference,
23702375
member->getDescriptiveKind(), member->getName());
@@ -2380,6 +2385,8 @@ namespace {
23802385
if (isolation.getActorType()->isDistributedActor() &&
23812386
!isolatedActor.isPotentiallyIsolated) {
23822387
tryMarkImplicitlyThrows(memberLoc, memberRef, context);
2388+
markNearestCallAsImplicitly(/*setAsync*/None, /*setThrows*/false,
2389+
/*setDistributedThunk*/true);
23832390
}
23842391

23852392
if (implicitAsyncResult == AsyncMarkingResult::FoundAsync)

test/Distributed/Runtime/distributed_actor_remote_functions.swift

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-distributed -parse-as-library) | %FileCheck %s --dump-input=always
1+
// RUN: %target-run-simple-swift(-Xfrontend -disable-availability-checking -Xfrontend -enable-experimental-distributed -parse-as-library) | %FileCheck %s --dump-input=always
22

33
// REQUIRES: executable_test
44
// REQUIRES: concurrency
@@ -41,6 +41,31 @@ distributed actor SomeSpecificDistributedActor {
4141
"local(\(#function))"
4242
}
4343

44+
distributed func callTaskSelf_inner() async throws -> String {
45+
"local(\(#function))"
46+
}
47+
distributed func callTaskSelf() async -> String {
48+
do {
49+
return try await Task {
50+
let called = try await callTaskSelf_inner() // shouldn't use the distributed thunk!
51+
return "local(\(#function)) -> \(called)"
52+
}.value
53+
} catch {
54+
return "WRONG local(\(#function)) thrown(\(error))"
55+
}
56+
}
57+
58+
distributed func callDetachedSelf() async -> String {
59+
do {
60+
return try await Task.detached {
61+
let called = try await self.callTaskSelf_inner() // shouldn't use the distributed thunk!
62+
return "local(\(#function)) -> \(called)"
63+
}.value
64+
} catch {
65+
return "WRONG local(\(#function)) thrown(\(error))"
66+
}
67+
}
68+
4469
// === errors
4570

4671
distributed func helloThrowsImplBoom() throws -> String {
@@ -74,6 +99,21 @@ extension SomeSpecificDistributedActor {
7499
"remote(\(#function))"
75100
}
76101

102+
@_dynamicReplacement(for:_remote_callTaskSelf())
103+
nonisolated func _remote_impl_callTaskSelf() async throws -> String {
104+
"remote(\(#function))"
105+
}
106+
107+
@_dynamicReplacement(for:_remote_callDetachedSelf())
108+
nonisolated func _remote_impl_callDetachedSelf() async throws -> String {
109+
"remote(\(#function))"
110+
}
111+
112+
@_dynamicReplacement(for:_remote_callTaskSelf_inner())
113+
nonisolated func _remote_impl_callTaskSelf_inner() async throws -> String {
114+
"remote(\(#function))"
115+
}
116+
77117
// === errors
78118

79119
@_dynamicReplacement(for:_remote_helloThrowsImplBoom())
@@ -101,9 +141,6 @@ func __isLocalActor(_ actor: AnyObject) -> Bool {
101141
@available(SwiftStdlib 5.5, *)
102142
struct ActorAddress: ActorIdentity {
103143
let address: String
104-
init(parse address : String) {
105-
self.address = address
106-
}
107144
}
108145

109146
@available(SwiftStdlib 5.5, *)
@@ -120,7 +157,7 @@ struct FakeTransport: ActorTransport {
120157

121158
func assignIdentity<Act>(_ actorType: Act.Type) -> AnyActorIdentity
122159
where Act: DistributedActor {
123-
.init(ActorAddress(parse: ""))
160+
.init(ActorAddress(address: ""))
124161
}
125162

126163
func actorReady<Act>(_ actor: Act) where Act: DistributedActor {
@@ -149,6 +186,12 @@ func test_remote_invoke(address: ActorAddress, transport: ActorTransport) async
149186
let h4 = try! await actor.hello()
150187
print("\(personality) - hello: \(h4)")
151188

189+
let h5 = try! await actor.callTaskSelf()
190+
print("\(personality) - callTaskSelf: \(h5)")
191+
192+
let h6 = try! await actor.callDetachedSelf()
193+
print("\(personality) - callDetachedSelf: \(h6)")
194+
152195
// error throws
153196
if __isRemoteActor(actor) {
154197
do {
@@ -180,6 +223,8 @@ func test_remote_invoke(address: ActorAddress, transport: ActorTransport) async
180223
// CHECK: local - helloAsync: local(helloAsync())
181224
// CHECK: local - helloThrows: local(helloThrows())
182225
// CHECK: local - hello: local(hello())
226+
// CHECK: local - callTaskSelf: local(callTaskSelf()) -> local(callTaskSelf_inner())
227+
// CHECK: local - callDetachedSelf: local(callDetachedSelf()) -> local(callTaskSelf_inner())
183228
// CHECK: local - helloThrowsImplBoom: Boom(whoFailed: "impl")
184229

185230
print("remote isRemote: \(__isRemoteActor(remote))")
@@ -189,6 +234,8 @@ func test_remote_invoke(address: ActorAddress, transport: ActorTransport) async
189234
// CHECK: remote - helloAsync: remote(_remote_impl_helloAsync())
190235
// CHECK: remote - helloThrows: remote(_remote_impl_helloThrows())
191236
// CHECK: remote - hello: remote(_remote_impl_hello())
237+
// CHECK: remote - callTaskSelf: remote(_remote_impl_callTaskSelf())
238+
// CHECK: remote - callDetachedSelf: remote(_remote_impl_callDetachedSelf())
192239
// CHECK: remote - helloThrowsTransportBoom: Boom(whoFailed: "transport")
193240

194241
print(local)
@@ -198,7 +245,7 @@ func test_remote_invoke(address: ActorAddress, transport: ActorTransport) async
198245
@available(SwiftStdlib 5.5, *)
199246
@main struct Main {
200247
static func main() async {
201-
let address = ActorAddress(parse: "")
248+
let address = ActorAddress(address: "")
202249
let transport = FakeTransport()
203250

204251
await test_remote_invoke(address: address, transport: transport)

0 commit comments

Comments
 (0)