Skip to content

Commit d20b8d2

Browse files
Merge pull request #65274 from nate-chandler/rdar107275872
[Concurrency] Nest stack traffic in withValue.
2 parents cc0f42f + 4fe988b commit d20b8d2

File tree

3 files changed

+92
-18
lines changed

3 files changed

+92
-18
lines changed

stdlib/public/BackDeployConcurrency/TaskLocal.swift

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,18 +141,41 @@ public final class TaskLocal<Value: Sendable>: Sendable, CustomStringConvertible
141141
@_backDeploy(before: SwiftStdlib 5.8)
142142
public func withValue<R>(_ valueDuringOperation: Value, operation: () async throws -> R,
143143
file: String = #file, line: UInt = #line) async rethrows -> R {
144+
return try await withValueImpl(valueDuringOperation, operation: operation, file: file, line: line)
145+
}
146+
147+
/// Implementation for withValue that consumes valueDuringOperation.
148+
///
149+
/// Because _taskLocalValuePush and _taskLocalValuePop involve calls to
150+
/// swift_task_alloc/swift_task_dealloc respectively unbeknownst to the
151+
/// compiler, compiler-emitted calls to swift_task_de/alloc must be avoided
152+
/// in a function that calls them.
153+
///
154+
/// A copy of valueDuringOperation is required because withValue borrows its
155+
/// argument but _taskLocalValuePush consumes its. Because
156+
/// valueDuringOperation is of generic type, its size is not generally known,
157+
/// so such a copy entails a stack allocation and a copy to that allocation.
158+
/// That stack traffic gets lowered to calls to
159+
/// swift_task_alloc/swift_task_deallloc.
160+
///
161+
/// Split the calls _taskLocalValuePush/Pop from the compiler-emitted calls
162+
/// to swift_task_de/alloc for the copy as follows:
163+
/// - withValue contains the compiler-emitted calls swift_task_de/alloc.
164+
/// - withValueImpl contains the calls to _taskLocalValuePush/Pop
165+
@inlinable
166+
@discardableResult
167+
@_unsafeInheritExecutor
168+
@available(SwiftStdlib 5.1, *) // back deploy requires we declare the availability explicitly on this method
169+
@_backDeploy(before: SwiftStdlib 5.9)
170+
internal func withValueImpl<R>(_ valueDuringOperation: __owned Value, operation: () async throws -> R,
171+
file: String = #fileID, line: UInt = #line) async rethrows -> R {
144172
// check if we're not trying to bind a value from an illegal context; this may crash
145173
_checkIllegalTaskLocalBindingWithinWithTaskGroup(file: file, line: line)
146174

147-
_taskLocalValuePush(key: key, value: valueDuringOperation)
148-
do {
149-
let result = try await operation()
150-
_taskLocalValuePop()
151-
return result
152-
} catch {
153-
_taskLocalValuePop()
154-
throw error
155-
}
175+
_taskLocalValuePush(key: key, value: consume valueDuringOperation)
176+
defer { _taskLocalValuePop() }
177+
178+
return try await operation()
156179
}
157180

158181
/// Binds the task-local to the specific value for the duration of the

stdlib/public/Concurrency/TaskLocal.swift

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,18 +141,41 @@ public final class TaskLocal<Value: Sendable>: Sendable, CustomStringConvertible
141141
@_backDeploy(before: SwiftStdlib 5.8)
142142
public func withValue<R>(_ valueDuringOperation: Value, operation: () async throws -> R,
143143
file: String = #fileID, line: UInt = #line) async rethrows -> R {
144+
return try await withValueImpl(valueDuringOperation, operation: operation, file: file, line: line)
145+
}
146+
147+
/// Implementation for withValue that consumes valueDuringOperation.
148+
///
149+
/// Because _taskLocalValuePush and _taskLocalValuePop involve calls to
150+
/// swift_task_alloc/swift_task_dealloc respectively unbeknownst to the
151+
/// compiler, compiler-emitted calls to swift_task_de/alloc must be avoided
152+
/// in a function that calls them.
153+
///
154+
/// A copy of valueDuringOperation is required because withValue borrows its
155+
/// argument but _taskLocalValuePush consumes its. Because
156+
/// valueDuringOperation is of generic type, its size is not generally known,
157+
/// so such a copy entails a stack allocation and a copy to that allocation.
158+
/// That stack traffic gets lowered to calls to
159+
/// swift_task_alloc/swift_task_deallloc.
160+
///
161+
/// Split the calls _taskLocalValuePush/Pop from the compiler-emitted calls
162+
/// to swift_task_de/alloc for the copy as follows:
163+
/// - withValue contains the compiler-emitted calls swift_task_de/alloc.
164+
/// - withValueImpl contains the calls to _taskLocalValuePush/Pop
165+
@inlinable
166+
@discardableResult
167+
@_unsafeInheritExecutor
168+
@available(SwiftStdlib 5.1, *) // back deploy requires we declare the availability explicitly on this method
169+
@_backDeploy(before: SwiftStdlib 5.9)
170+
internal func withValueImpl<R>(_ valueDuringOperation: __owned Value, operation: () async throws -> R,
171+
file: String = #fileID, line: UInt = #line) async rethrows -> R {
144172
// check if we're not trying to bind a value from an illegal context; this may crash
145173
_checkIllegalTaskLocalBindingWithinWithTaskGroup(file: file, line: line)
146174

147-
_taskLocalValuePush(key: key, value: valueDuringOperation)
148-
do {
149-
let result = try await operation()
150-
_taskLocalValuePop()
151-
return result
152-
} catch {
153-
_taskLocalValuePop()
154-
throw error
155-
}
175+
_taskLocalValuePush(key: key, value: consume valueDuringOperation)
176+
defer { _taskLocalValuePop() }
177+
178+
return try await operation()
156179
}
157180

158181
/// Binds the task-local to the specific value for the duration of the
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-build-swift -O -Xfrontend -disable-availability-checking %s -parse-as-library -module-name main -o %t/main
3+
// RUN: %target-codesign %t/main
4+
// RUN: %target-run %t/main | %FileCheck %s
5+
6+
// REQUIRES: objc_interop
7+
// REQUIRES: concurrency
8+
// REQUIRES: executable_test
9+
// REQUIRES: concurrency_runtime
10+
11+
import Foundation
12+
13+
@main struct M {
14+
@TaskLocal static var v: UUID = UUID()
15+
static func test(_ t: UUID) async {
16+
await Self.$v.withValue(t) {
17+
await Task.sleep(1)
18+
print(Self.$v.get())
19+
}
20+
}
21+
static func main() async {
22+
// CHECK: before
23+
print("before")
24+
await test(UUID())
25+
// CHECK: after
26+
print("after")
27+
}
28+
}

0 commit comments

Comments
 (0)