Skip to content

Commit 5f09745

Browse files
committed
Make CheckedContinuation thread-safe and resilient.
1 parent c151f4b commit 5f09745

File tree

1 file changed

+99
-37
lines changed

1 file changed

+99
-37
lines changed

stdlib/public/Concurrency/CheckedContinuation.swift

Lines changed: 99 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,88 @@ import Swift
33
@_silgen_name("swift_continuation_logFailedCheck")
44
internal func logFailedCheck(_ message: UnsafeRawPointer)
55

6+
/// Implementation class that holds the `UnsafeContinuation` instance for
7+
/// a `CheckedContinuation`.
8+
internal final class CheckedContinuationCanary {
9+
// The instance state is stored in tail-allocated raw memory, so that
10+
// we can atomically check the continuation state.
11+
12+
private init() { fatalError("must use create") }
13+
14+
private static func _create(continuation: UnsafeRawPointer, function: String)
15+
-> Self {
16+
let instance = Builtin.allocWithTailElems_1(self,
17+
1._builtinWordValue,
18+
(UnsafeRawPointer?, String).self)
19+
20+
instance._continuationPtr.initialize(to: continuation)
21+
instance._functionPtr.initialize(to: function)
22+
return instance
23+
}
24+
25+
private var _continuationPtr: UnsafeMutablePointer<UnsafeRawPointer?> {
26+
return UnsafeMutablePointer<UnsafeRawPointer?>(
27+
Builtin.projectTailElems(self, (UnsafeRawPointer?, String).self))
28+
}
29+
private var _functionPtr: UnsafeMutablePointer<String> {
30+
let tailPtr = UnsafeMutableRawPointer(
31+
Builtin.projectTailElems(self, (UnsafeRawPointer?, String).self))
32+
33+
let functionPtr = tailPtr
34+
+ MemoryLayout<(UnsafeRawPointer?, String)>.offset(of: \(UnsafeRawPointer?, String).1)!
35+
36+
return functionPtr.assumingMemoryBound(to: String.self)
37+
}
38+
39+
internal static func create<T>(continuation: UnsafeContinuation<T>,
40+
function: String) -> Self {
41+
return _create(
42+
continuation: unsafeBitCast(continuation, to: UnsafeRawPointer.self),
43+
function: function)
44+
}
45+
46+
internal static func create<T>(continuation: UnsafeThrowingContinuation<T>,
47+
function: String) -> Self {
48+
return _create(
49+
continuation: unsafeBitCast(continuation, to: UnsafeRawPointer.self),
50+
function: function)
51+
}
52+
53+
internal var function: String {
54+
return _functionPtr.pointee
55+
}
56+
57+
// Take the continuation away from the container, or return nil if it's
58+
// already been taken.
59+
private func _takeContinuation() -> UnsafeRawPointer? {
60+
// Atomically exchange the current continuation value with a null pointer.
61+
let rawContinuationPtr = unsafeBitCast(_continuationPtr,
62+
to: Builtin.RawPointer.self)
63+
let rawOld = Builtin.atomicrmw_xchg_seqcst_Word(rawContinuationPtr,
64+
0._builtinWordValue)
65+
66+
return unsafeBitCast(rawOld, to: UnsafeRawPointer?.self)
67+
}
68+
69+
internal func takeContinuation<T>() -> UnsafeContinuation<T>? {
70+
return unsafeBitCast(_takeContinuation(),
71+
to: UnsafeContinuation<T>.self)
72+
}
73+
internal func takeThrowingContinuation<T>() -> UnsafeThrowingContinuation<T>? {
74+
return unsafeBitCast(_takeContinuation(),
75+
to: UnsafeThrowingContinuation<T>.self)
76+
}
77+
78+
deinit {
79+
_functionPtr.deinitialize(count: 1)
80+
// Log if the continuation was never consumed before the instance was
81+
// destructed.
82+
if _continuationPtr.pointee != nil {
83+
logFailedCheck("SWIFT TASK CONTINUATION MISUSE: \(function) leaked its continuation!\n")
84+
}
85+
}
86+
}
87+
688
/// A wrapper class for `UnsafeContinuation` that logs misuses of the
789
/// continuation, logging a message if the continuation is resumed
890
/// multiple times, or if an object is destroyed without its continuation
@@ -27,9 +109,8 @@ internal func logFailedCheck(_ message: UnsafeRawPointer)
27109
/// Changing a call of `withUnsafeContinuation` into a call of
28110
/// `withCheckedContinuation` should be enough to obtain the extra checking
29111
/// without further source modification in most circumstances.
30-
public final class CheckedContinuation<T> {
31-
var continuation: UnsafeContinuation<T>?
32-
var function: String
112+
public struct CheckedContinuation<T> {
113+
let canary: CheckedContinuationCanary
33114

34115
/// Initialize a `CheckedContinuation` wrapper around an
35116
/// `UnsafeContinuation`.
@@ -46,8 +127,9 @@ public final class CheckedContinuation<T> {
46127
/// source for the continuation, used to identify the continuation in
47128
/// runtime diagnostics related to misuse of this continuation.
48129
public init(continuation: UnsafeContinuation<T>, function: String = #function) {
49-
self.continuation = continuation
50-
self.function = function
130+
canary = CheckedContinuationCanary.create(
131+
continuation: continuation,
132+
function: function)
51133
}
52134

53135
/// Resume the task awaiting the continuation by having it return normally
@@ -57,19 +139,10 @@ public final class CheckedContinuation<T> {
57139
/// already been resumed through this object, then the attempt to resume
58140
/// the continuation again will be logged, but otherwise have no effect.
59141
public func resume(returning x: __owned T) {
60-
if let c = continuation {
142+
if let c: UnsafeContinuation<T> = canary.takeContinuation() {
61143
c.resume(returning: x)
62-
// Clear out the continuation so we don't try to resume again
63-
continuation = nil
64144
} else {
65-
logFailedCheck("SWIFT TASK CONTINUATION MISUSE: \(function) tried to resume its continuation more than once, returning \(x)!\n")
66-
}
67-
}
68-
69-
/// Log if the object is deallocated before its continuation is resumed.
70-
deinit {
71-
if continuation != nil {
72-
logFailedCheck("SWIFT TASK CONTINUATION MISUSE: \(function) leaked its continuation!\n")
145+
logFailedCheck("SWIFT TASK CONTINUATION MISUSE: \(canary.function) tried to resume its continuation more than once, returning \(x)!\n")
73146
}
74147
}
75148
}
@@ -107,9 +180,8 @@ public func withCheckedContinuation<T>(
107180
/// Changing a call of `withUnsafeThrowingContinuation` into a call of
108181
/// `withCheckedThrowingContinuation` should be enough to obtain the extra checking
109182
/// without further source modification in most circumstances.
110-
public final class CheckedThrowingContinuation<T> {
111-
var continuation: UnsafeThrowingContinuation<T>?
112-
var function: String
183+
public struct CheckedThrowingContinuation<T> {
184+
let canary: CheckedContinuationCanary
113185

114186
/// Initialize a `CheckedThrowingContinuation` wrapper around an
115187
/// `UnsafeThrowingContinuation`.
@@ -126,8 +198,9 @@ public final class CheckedThrowingContinuation<T> {
126198
/// source for the continuation, used to identify the continuation in
127199
/// runtime diagnostics related to misuse of this continuation.
128200
public init(continuation: UnsafeThrowingContinuation<T>, function: String = #function) {
129-
self.continuation = continuation
130-
self.function = function
201+
canary = CheckedContinuationCanary.create(
202+
continuation: continuation,
203+
function: function)
131204
}
132205

133206
/// Resume the task awaiting the continuation by having it return normally
@@ -138,12 +211,10 @@ public final class CheckedThrowingContinuation<T> {
138211
/// or by `resume(throwing:)`, then the attempt to resume
139212
/// the continuation again will be logged, but otherwise have no effect.
140213
public func resume(returning x: __owned T) {
141-
if let c = continuation {
214+
if let c: UnsafeThrowingContinuation<T> = canary.takeThrowingContinuation() {
142215
c.resume(returning: x)
143-
// Clear out the continuation so we don't try to resume again
144-
continuation = nil
145216
} else {
146-
logFailedCheck("SWIFT TASK CONTINUATION MISUSE: \(function) tried to resume its continuation more than once, returning \(x)!\n")
217+
logFailedCheck("SWIFT TASK CONTINUATION MISUSE: \(canary.function) tried to resume its continuation more than once, returning \(x)!\n")
147218
}
148219
}
149220

@@ -155,19 +226,10 @@ public final class CheckedThrowingContinuation<T> {
155226
/// or by `resume(throwing:)`, then the attempt to resume
156227
/// the continuation again will be logged, but otherwise have no effect.
157228
public func resume(throwing x: __owned Error) {
158-
if let c = continuation {
229+
if let c: UnsafeThrowingContinuation<T> = canary.takeThrowingContinuation() {
159230
c.resume(throwing: x)
160-
// Clear out the continuation so we don't try to resume again
161-
continuation = nil
162231
} else {
163-
logFailedCheck("SWIFT TASK CONTINUATION MISUSE: \(function) tried to resume its continuation more than once, throwing \(x)!\n")
164-
}
165-
}
166-
167-
/// Log if the object is deallocated before its continuation is resumed.
168-
deinit {
169-
if continuation != nil {
170-
logFailedCheck("SWIFT TASK CONTINUATION MISUSE: \(function) leaked its continuation!\n")
232+
logFailedCheck("SWIFT TASK CONTINUATION MISUSE: \(canary.function) tried to resume its continuation more than once, throwing \(x)!\n")
171233
}
172234
}
173235
}
@@ -176,7 +238,7 @@ public func withCheckedThrowingContinuation<T>(
176238
function: String = #function,
177239
_ body: (CheckedThrowingContinuation<T>) -> Void
178240
) async throws -> T {
179-
return await try withUnsafeThrowingContinuation {
241+
return try await withUnsafeThrowingContinuation {
180242
body(CheckedThrowingContinuation(continuation: $0, function: function))
181243
}
182244
}

0 commit comments

Comments
 (0)