Skip to content

Commit f0781b1

Browse files
committed
[TaskLocals] prettier API thanks to default inits
1 parent 6f3dac1 commit f0781b1

8 files changed

+198
-152
lines changed

lib/Sema/TypeCheckPropertyWrapper.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ AttachedPropertyWrappersRequest::evaluate(Evaluator &evaluator,
508508
continue;
509509
}
510510
}
511-
511+
512512
result.push_back(mutableAttr);
513513
}
514514

stdlib/public/Concurrency/TaskLocal.swift

Lines changed: 54 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,32 @@ import Swift
1717
/// value for lookups in the task local storage.
1818
@propertyWrapper
1919
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
20-
public final class TaskLocal<Value>: CustomStringConvertible {
21-
var defaultValue: Value
20+
public final class TaskLocal<Value: Sendable>: CustomStringConvertible {
21+
// only reason this is ! is to store the wrapper `self` in Access so we
22+
// can use its identity as the key for lookups.
23+
private var access: Access!
2224

23-
// Note: could overload with additional parameters to support other features
24-
public init(wrappedValue: Value) {
25-
self.defaultValue = wrappedValue
25+
public init(default defaultValue: Value) {
26+
self.access = Access(key: self, defaultValue: defaultValue)
2627
}
2728

28-
public var wrappedValue: Value {
29-
get {
30-
return withUnsafeCurrentTask { task in
29+
30+
public struct Access: CustomStringConvertible {
31+
let key: Builtin.RawPointer
32+
let defaultValue: Value
33+
34+
init(key: TaskLocal<Value>, defaultValue: Value) {
35+
self.key = unsafeBitCast(key, to: Builtin.RawPointer.self)
36+
self.defaultValue = defaultValue
37+
}
38+
39+
public func get() -> Value {
40+
withUnsafeCurrentTask { task in
3141
guard let task = task else {
3242
return self.defaultValue
3343
}
3444

35-
let value = _taskLocalValueGet(
36-
task._task, key: unsafeBitCast(self, to: Builtin.RawPointer.self))
45+
let value = _taskLocalValueGet(task._task, key: key)
3746

3847
guard let rawValue = value else {
3948
return self.defaultValue
@@ -46,41 +55,54 @@ public final class TaskLocal<Value>: CustomStringConvertible {
4655
}
4756
}
4857

49-
@available(*, unavailable, message: "use ‘$myTaskLocal.withValue(_:do:)’ instead")
50-
set {
51-
fatalError("Illegal attempt to set a \(Self.self) value, use `withValue(...) { ... }` instead.")
58+
/// Execute the `body` closure
59+
@discardableResult
60+
public func withValue<R>(_ valueDuringBody: Value, do body: () async throws -> R,
61+
file: String = #file, line: UInt = #line) async rethrows -> R {
62+
// check if we're not trying to bind a value from an illegal context; this may crash
63+
_checkIllegalTaskLocalBindingWithinWithTaskGroup(file: file, line: line)
64+
65+
// we need to escape the `_task` since the withUnsafeCurrentTask closure is not `async`.
66+
// this is safe, since we know the task will remain alive because we are running inside of it.
67+
let _task = withUnsafeCurrentTask { task in
68+
task!._task // !-safe, guaranteed to have task available inside async function
69+
}
70+
71+
_taskLocalValuePush(_task, key: key, value: valueDuringBody)
72+
defer { _taskLocalValuePop(_task) }
73+
74+
return try await body()
5275
}
53-
}
5476

55-
public var projectedValue: TaskLocal<Value> {
56-
self
77+
public var description: String {
78+
"TaskLocal<\(Value.self)>.Access"
79+
}
5780
}
5881

59-
/// Execute the `body` closure
60-
@discardableResult
61-
public func withValue<R>(_ valueDuringBody: Value, do body: () async throws -> R,
62-
file: String = #file, line: UInt = #line) async rethrows -> R {
63-
// check if we're not trying to bind a value from an illegal context; this may crash
64-
_checkIllegalTaskLocalBindingWithinWithTaskGroup(file: file, line: line)
65-
66-
// we need to escape the `_task` since the withUnsafeCurrentTask closure is not `async`.
67-
// this is safe, since we know the task will remain alive because we are running inside of it.
68-
let _task = withUnsafeCurrentTask { task in
69-
task!._task // !-safe, guaranteed to have task available inside async function
82+
public var wrappedValue: TaskLocal<Value>.Access {
83+
get {
84+
self.access
7085
}
7186

72-
_taskLocalValuePush(_task, key: unsafeBitCast(self, to: Builtin.RawPointer.self), value: valueDuringBody)
73-
defer { _taskLocalValuePop(_task) }
74-
75-
return try await body()
87+
@available(*, unavailable, message: "use 'myTaskLocal.withValue(_:do:)' instead")
88+
set {
89+
fatalError("Illegal attempt to set a \(Self.self) value, use `withValue(...) { ... }` instead.")
90+
}
7691
}
7792

7893
public var description: String {
79-
"\(Self.self)(\(wrappedValue))"
94+
"\(Self.self)(defaultValue: \(self.access.defaultValue))"
8095
}
8196

8297
}
8398

99+
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
100+
extension TaskLocal {
101+
public convenience init<V>() where Value == Optional<V> {
102+
self.init(default: nil)
103+
}
104+
}
105+
84106
// ==== ------------------------------------------------------------------------
85107

86108
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)

test/Concurrency/Runtime/async_task_locals_async_let.swift

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,19 @@
1010

1111
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
1212
enum TL {
13-
@TaskLocal
14-
static var number: Int = 0
13+
@TaskLocal(default: 0)
14+
static var number
1515
}
1616

1717
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
1818
@discardableResult
19-
func printTaskLocal<V, Key>(
20-
_ key: Key,
19+
func printTaskLocal<V>(
20+
_ key: TaskLocal<V>.Access,
2121
_ expected: V? = nil,
2222
file: String = #file, line: UInt = #line
23-
) -> V? where Key: TaskLocal<V> {
24-
let value = key
25-
print("\(value) at \(file):\(line)")
23+
) -> V? {
24+
let value = key.get()
25+
print("\(key) (\(value)) at \(file):\(line)")
2626
if let expected = expected {
2727
assert("\(expected)" == "\(value)",
2828
"Expected [\(expected)] but found: \(value), at \(file):\(line)")
@@ -34,14 +34,14 @@ func printTaskLocal<V, Key>(
3434

3535
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
3636
func async_let_nested() async {
37-
_ = printTaskLocal(TL.$number) // CHECK: TaskLocal<Int>(0)
38-
async let x1: () = TL.$number.withValue(2) {
39-
async let x2 = printTaskLocal(TL.$number) // CHECK: TaskLocal<Int>(2)
37+
printTaskLocal(TL.number) // CHECK: TaskLocal<Int>.Access (0)
38+
async let x1: () = TL.number.withValue(2) {
39+
async let x2 = printTaskLocal(TL.number) // CHECK: TaskLocal<Int>.Access (2)
4040

4141
@Sendable
4242
func test() async {
43-
printTaskLocal(TL.$number) // CHECK: TaskLocal<Int>(2)
44-
async let x31 = printTaskLocal(TL.$number) // CHECK: TaskLocal<Int>(2)
43+
printTaskLocal(TL.number) // CHECK: TaskLocal<Int>.Access (2)
44+
async let x31 = printTaskLocal(TL.number) // CHECK: TaskLocal<Int>.Access (2)
4545
_ = await x31
4646
}
4747
async let x3: () = test()
@@ -51,18 +51,18 @@ func async_let_nested() async {
5151
}
5252

5353
_ = await x1
54-
printTaskLocal(TL.$number) // CHECK: TaskLocal<Int>(0)
54+
printTaskLocal(TL.number) // CHECK: TaskLocal<Int>.Access (0)
5555
}
5656

5757
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
5858
func async_let_nested_skip_optimization() async {
59-
async let x1: Int? = TL.$number.withValue(2) {
59+
async let x1: Int? = TL.number.withValue(2) {
6060
async let x2: Int? = { () async -> Int? in
6161
async let x3: Int? = { () async -> Int? in
6262
async let x4: Int? = { () async -> Int? in
6363
async let x5: Int? = { () async -> Int? in
64-
assert(TL.number == 2)
65-
async let xx = printTaskLocal(TL.$number) // CHECK: TaskLocal<Int>(2)
64+
assert(TL.number.get() == 2)
65+
async let xx = printTaskLocal(TL.number) // CHECK: TaskLocal<Int>.Access (2)
6666
return await xx
6767
}()
6868
return await x5

test/Concurrency/Runtime/async_task_locals_basic.swift

Lines changed: 55 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
// UNSUPPORTED: use_os_stdlib
99
// UNSUPPORTED: back_deployment_runtime
1010

11-
class StringLike: CustomStringConvertible {
11+
final class StringLike: Sendable, CustomStringConvertible {
1212
let value: String
1313
init(_ value: String) {
1414
self.value = value
@@ -20,21 +20,21 @@ class StringLike: CustomStringConvertible {
2020
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
2121
enum TL {
2222

23-
@TaskLocal
24-
static var string: String = "<undefined>"
23+
@TaskLocal(default: "<undefined>")
24+
static var string
2525

26-
@TaskLocal
27-
static var number: Int = 0
26+
@TaskLocal(default: 0)
27+
static var number
2828

29-
@TaskLocal
30-
static var never: StringLike = .init("<never>")
29+
@TaskLocal(default: StringLike("<never>"))
30+
static var never
3131

32-
@TaskLocal
33-
static var clazz: ClassTaskLocal? = nil
32+
@TaskLocal()
33+
static var clazz: TaskLocal<ClassTaskLocal?>.Access
3434
}
3535

3636
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
37-
final class ClassTaskLocal {
37+
final class ClassTaskLocal: Sendable {
3838
init() {
3939
print("clazz init \(ObjectIdentifier(self))")
4040
}
@@ -46,15 +46,15 @@ final class ClassTaskLocal {
4646

4747
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
4848
@discardableResult
49-
func printTaskLocal<V, Key>(
50-
_ key: Key,
49+
func printTaskLocal<V>(
50+
_ key: TaskLocal<V>.Access,
5151
_ expected: V? = nil,
5252
file: String = #file, line: UInt = #line
53-
) -> V? where Key: TaskLocal<V> {
54-
let value = key
55-
print("\(value) at \(file):\(line)")
53+
) -> V? {
54+
let value = key.get()
55+
print("\(key) (\(value)) at \(file):\(line)")
5656
if let expected = expected {
57-
assert("\(expected)" == "\(value.wrappedValue)",
57+
assert("\(expected)" == "\(value)",
5858
"Expected [\(expected)] but found: \(value), at \(file):\(line)")
5959
}
6060
return expected
@@ -64,20 +64,20 @@ func printTaskLocal<V, Key>(
6464

6565
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
6666
func simple() async {
67-
printTaskLocal(TL.$number) // CHECK: TaskLocal<Int>(0)
68-
await TL.$number.withValue(1) {
69-
printTaskLocal(TL.$number) // CHECK-NEXT: TaskLocal<Int>(1)
67+
printTaskLocal(TL.number) // CHECK: TaskLocal<Int>.Access (0)
68+
await TL.number.withValue(1) {
69+
printTaskLocal(TL.number) // CHECK-NEXT: TaskLocal<Int>.Access (1)
7070
}
7171
}
7272

7373
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
7474
func simple_deinit() async {
75-
await TL.$clazz.withValue(ClassTaskLocal()) {
75+
await TL.clazz.withValue(ClassTaskLocal()) {
7676
// CHECK: clazz init [[C:.*]]
77-
printTaskLocal(TL.$clazz) // CHECK: TaskLocal<Optional<ClassTaskLocal>>(Optional(main.ClassTaskLocal))
77+
printTaskLocal(TL.clazz) // CHECK: TaskLocal<Optional<ClassTaskLocal>>.Access (Optional(main.ClassTaskLocal))
7878
}
7979
// CHECK: clazz deinit [[C]]
80-
printTaskLocal(TL.$clazz) // CHECK: TaskLocal<Optional<ClassTaskLocal>>(nil)
80+
printTaskLocal(TL.clazz) // CHECK: TaskLocal<Optional<ClassTaskLocal>>.Access (nil)
8181
}
8282

8383
struct Boom: Error {
@@ -86,7 +86,7 @@ struct Boom: Error {
8686
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
8787
func simple_throw() async {
8888
do {
89-
try await TL.$clazz.withValue(ClassTaskLocal()) {
89+
try await TL.clazz.withValue(ClassTaskLocal()) {
9090
throw Boom(value: "oh no!")
9191
}
9292
} catch {
@@ -97,59 +97,59 @@ func simple_throw() async {
9797

9898
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
9999
func nested() async {
100-
printTaskLocal(TL.$string) // CHECK: TaskLocal<String>(<undefined>)
101-
await TL.$string.withValue("hello") {
102-
printTaskLocal(TL.$number) // CHECK-NEXT: TaskLocal<Int>(0)
103-
printTaskLocal(TL.$string)// CHECK-NEXT: TaskLocal<String>(hello)
104-
await TL.$number.withValue(2) {
105-
printTaskLocal(TL.$number) // CHECK-NEXT: TaskLocal<Int>(2)
106-
printTaskLocal(TL.$string, "hello") // CHECK: TaskLocal<String>(hello)
100+
printTaskLocal(TL.string) // CHECK: TaskLocal<String>.Access (<undefined>)
101+
await TL.string.withValue("hello") {
102+
printTaskLocal(TL.number) // CHECK-NEXT: TaskLocal<Int>.Access (0)
103+
printTaskLocal(TL.string)// CHECK-NEXT: TaskLocal<String>.Access (hello)
104+
await TL.number.withValue(2) {
105+
printTaskLocal(TL.number) // CHECK-NEXT: TaskLocal<Int>.Access (2)
106+
printTaskLocal(TL.string, "hello") // CHECK: TaskLocal<String>.Access (hello)
107107
}
108-
printTaskLocal(TL.$number) // CHECK-NEXT: TaskLocal<Int>(0)
109-
printTaskLocal(TL.$string) // CHECK-NEXT: TaskLocal<String>(hello)
108+
printTaskLocal(TL.number) // CHECK-NEXT: TaskLocal<Int>.Access (0)
109+
printTaskLocal(TL.string) // CHECK-NEXT: TaskLocal<String>.Access (hello)
110110
}
111-
printTaskLocal(TL.$number) // CHECK-NEXT: TaskLocal<Int>(0)
112-
printTaskLocal(TL.$string) // CHECK-NEXT: TaskLocal<String>(<undefined>)
111+
printTaskLocal(TL.number) // CHECK-NEXT: TaskLocal<Int>.Access (0)
112+
printTaskLocal(TL.string) // CHECK-NEXT: TaskLocal<String>.Access (<undefined>)
113113
}
114114

115115
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
116116
func nested_allContribute() async {
117-
printTaskLocal(TL.$string) // CHECK: TaskLocal<String>(<undefined>)
118-
await TL.$string.withValue("one") {
119-
printTaskLocal(TL.$string, "one")// CHECK-NEXT: TaskLocal<String>(one)
120-
await TL.$string.withValue("two") {
121-
printTaskLocal(TL.$string, "two") // CHECK-NEXT: TaskLocal<String>(two)
122-
await TL.$string.withValue("three") {
123-
printTaskLocal(TL.$string, "three") // CHECK-NEXT: TaskLocal<String>(three)
117+
printTaskLocal(TL.string) // CHECK: TaskLocal<String>.Access (<undefined>)
118+
await TL.string.withValue("one") {
119+
printTaskLocal(TL.string, "one")// CHECK-NEXT: TaskLocal<String>.Access (one)
120+
await TL.string.withValue("two") {
121+
printTaskLocal(TL.string, "two") // CHECK-NEXT: TaskLocal<String>.Access (two)
122+
await TL.string.withValue("three") {
123+
printTaskLocal(TL.string, "three") // CHECK-NEXT: TaskLocal<String>.Access (three)
124124
}
125-
printTaskLocal(TL.$string, "two") // CHECK-NEXT: TaskLocal<String>(two)
125+
printTaskLocal(TL.string, "two") // CHECK-NEXT: TaskLocal<String>.Access (two)
126126
}
127-
printTaskLocal(TL.$string, "one")// CHECK-NEXT: TaskLocal<String>(one)
127+
printTaskLocal(TL.string, "one")// CHECK-NEXT: TaskLocal<String>.Access (one)
128128
}
129-
printTaskLocal(TL.$string) // CHECK-NEXT: TaskLocal<String>(<undefined>)
129+
printTaskLocal(TL.string) // CHECK-NEXT: TaskLocal<String>.Access (<undefined>)
130130
}
131131

132132
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
133133
func nested_3_onlyTopContributes() async {
134-
printTaskLocal(TL.$string) // CHECK: TaskLocal<String>(<undefined>)
135-
await TL.$string.withValue("one") {
136-
printTaskLocal(TL.$string)// CHECK-NEXT: TaskLocal<String>(one)
137-
await TL.$number.withValue(2) {
138-
printTaskLocal(TL.$string) // CHECK-NEXT: TaskLocal<String>(one)
139-
await TL.$number.withValue(3) {
140-
printTaskLocal(TL.$string) // CHECK-NEXT: TaskLocal<String>(one)
134+
printTaskLocal(TL.string) // CHECK: TaskLocal<String>.Access (<undefined>)
135+
await TL.string.withValue("one") {
136+
printTaskLocal(TL.string)// CHECK-NEXT: TaskLocal<String>.Access (one)
137+
await TL.number.withValue(2) {
138+
printTaskLocal(TL.string) // CHECK-NEXT: TaskLocal<String>.Access (one)
139+
await TL.number.withValue(3) {
140+
printTaskLocal(TL.string) // CHECK-NEXT: TaskLocal<String>.Access (one)
141141
}
142-
printTaskLocal(TL.$string) // CHECK-NEXT: TaskLocal<String>(one)
142+
printTaskLocal(TL.string) // CHECK-NEXT: TaskLocal<String>.Access (one)
143143
}
144-
printTaskLocal(TL.$string)// CHECK-NEXT: TaskLocal<String>(one)
144+
printTaskLocal(TL.string)// CHECK-NEXT: TaskLocal<String>.Access (one)
145145
}
146-
printTaskLocal(TL.$string) // CHECK-NEXT: TaskLocal<String>(<undefined>)
146+
printTaskLocal(TL.string) // CHECK-NEXT: TaskLocal<String>.Access (<undefined>)
147147
}
148148

149149
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
150150
func withLocal_body_mustNotEscape() async {
151151
var something = "Nice"
152-
await TL.$string.withValue("xxx") {
152+
await TL.string.withValue("xxx") {
153153
something = "very nice"
154154
}
155155
_ = something // silence not used warning

0 commit comments

Comments
 (0)