Skip to content

Commit df5ff42

Browse files
committed
[TaskLocals] set task local value in synchronous function
1 parent f0781b1 commit df5ff42

File tree

2 files changed

+90
-1
lines changed

2 files changed

+90
-1
lines changed

stdlib/public/Concurrency/TaskLocal.swift

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public final class TaskLocal<Value: Sendable>: CustomStringConvertible {
2828

2929

3030
public struct Access: CustomStringConvertible {
31-
let key: Builtin.RawPointer
31+
let key: Builtin.RawPointer
3232
let defaultValue: Value
3333

3434
init(key: TaskLocal<Value>, defaultValue: Value) {
@@ -103,6 +103,29 @@ extension TaskLocal {
103103
}
104104
}
105105

106+
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
107+
extension UnsafeCurrentTask {
108+
109+
/// Allows for executing a synchronous `body` while binding a task-local value
110+
/// in the current task.
111+
///
112+
/// This function MUST NOT be invoked by any other task than the current task
113+
/// represented by this object.
114+
@discardableResult
115+
public func withTaskLocal<Value: Sendable, R>(
116+
_ access: TaskLocal<Value>.Access, boundTo valueDuringBody: Value,
117+
do body: () throws -> R,
118+
file: String = #file, line: UInt = #line) rethrows -> R {
119+
// check if we're not trying to bind a value from an illegal context; this may crash
120+
_checkIllegalTaskLocalBindingWithinWithTaskGroup(file: file, line: line)
121+
122+
_taskLocalValuePush(self._task, key: access.key, value: valueDuringBody)
123+
defer { _taskLocalValuePop(_task) }
124+
125+
return try body()
126+
}
127+
}
128+
106129
// ==== ------------------------------------------------------------------------
107130

108131
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-concurrency -parse-as-library %import-libdispatch) | %FileCheck %s
2+
3+
// REQUIRES: executable_test
4+
// REQUIRES: concurrency
5+
// REQUIRES: libdispatch
6+
7+
// rdar://76038845
8+
// UNSUPPORTED: use_os_stdlib
9+
// UNSUPPORTED: back_deployment_runtime
10+
11+
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
12+
enum TL {
13+
@TaskLocal(default: 0)
14+
static var number
15+
}
16+
17+
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
18+
@discardableResult
19+
func printTaskLocal<V>(
20+
_ key: TaskLocal<V>.Access,
21+
_ expected: V? = nil,
22+
file: String = #file, line: UInt = #line
23+
) -> V? {
24+
let value = key.get()
25+
print("\(key) (\(value)) at \(file):\(line)")
26+
if let expected = expected {
27+
assert("\(expected)" == "\(value)",
28+
"Expected [\(expected)] but found: \(value), at \(file):\(line)")
29+
}
30+
return expected
31+
}
32+
33+
// ==== ------------------------------------------------------------------------
34+
35+
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
36+
func synchronous_bind() async {
37+
38+
func synchronous() {
39+
printTaskLocal(TL.number) // CHECK: TaskLocal<Int>.Access (1111)
40+
41+
withUnsafeCurrentTask { task in
42+
guard let task = task else {
43+
fatalError()
44+
}
45+
46+
task.withTaskLocal(TL.number, boundTo: 2222) {
47+
printTaskLocal(TL.number) // CHECK: TaskLocal<Int>.Access (2222)
48+
}
49+
50+
printTaskLocal(TL.number) // CHECK: TaskLocal<Int>.Access (1111)
51+
}
52+
53+
printTaskLocal(TL.number) // CHECK: TaskLocal<Int>.Access (1111)
54+
}
55+
56+
await TL.number.withValue(1111) {
57+
synchronous()
58+
}
59+
}
60+
61+
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
62+
@main struct Main {
63+
static func main() async {
64+
await synchronous_bind()
65+
}
66+
}

0 commit comments

Comments
 (0)