Skip to content

Commit 71cc766

Browse files
authored
Correct lifetime based cancellation of AsyncStream and AsyncThrowingStream (#40087)
1 parent 127874e commit 71cc766

File tree

2 files changed

+53
-12
lines changed

2 files changed

+53
-12
lines changed

stdlib/public/Concurrency/AsyncStream.swift

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,22 @@ public struct AsyncStream<Element> {
139139
}
140140
}
141141

142-
let produce: () async -> Element?
142+
final class _Context {
143+
let storage: _Storage?
144+
let produce: () async -> Element?
145+
146+
init(storage: _Storage? = nil, produce: @escaping () async -> Element?) {
147+
self.storage = storage
148+
self.produce = produce
149+
}
150+
151+
deinit {
152+
storage?.cancel()
153+
}
154+
}
155+
156+
let context: _Context
157+
143158

144159
/// Construct a AsyncStream buffering given an Element type.
145160
///
@@ -163,7 +178,7 @@ public struct AsyncStream<Element> {
163178
_ build: (Continuation) -> Void
164179
) {
165180
let storage: _Storage = .create(limit: limit)
166-
self.init(unfolding: storage.next)
181+
context = _Context(storage: storage, produce: storage.next)
167182
build(Continuation(storage: storage))
168183
}
169184

@@ -174,7 +189,7 @@ public struct AsyncStream<Element> {
174189
) {
175190
let storage: _AsyncStreamCriticalStorage<Optional<() async -> Element?>>
176191
= .create(produce)
177-
self.produce = {
192+
context = _Context {
178193
return await withTaskCancellationHandler {
179194
guard let result = await storage.value?() else {
180195
storage.value = nil
@@ -198,7 +213,7 @@ extension AsyncStream: AsyncSequence {
198213
/// concurrently and contends with another call to next is a programmer error
199214
/// and will fatalError.
200215
public struct Iterator: AsyncIteratorProtocol {
201-
let produce: () async -> Element?
216+
let context: _Context
202217

203218
/// The next value from the AsyncStream.
204219
///
@@ -210,13 +225,13 @@ extension AsyncStream: AsyncSequence {
210225
/// awaiting a value, this will terminate the AsyncStream and next may return nil
211226
/// immediately (or will return nil on subsequent calls)
212227
public mutating func next() async -> Element? {
213-
await produce()
228+
await context.produce()
214229
}
215230
}
216231

217232
/// Construct an iterator.
218233
public func makeAsyncIterator() -> Iterator {
219-
return Iterator(produce: produce)
234+
return Iterator(context: context)
220235
}
221236
}
222237

stdlib/public/Concurrency/AsyncThrowingStream.swift

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,21 @@ public struct AsyncThrowingStream<Element, Failure: Error> {
108108
}
109109
}
110110

111-
let produce: () async throws -> Element?
111+
final class _Context {
112+
let storage: _Storage?
113+
let produce: () async throws -> Element?
114+
115+
init(storage: _Storage? = nil, produce: @escaping () async throws -> Element?) {
116+
self.storage = storage
117+
self.produce = produce
118+
}
119+
120+
deinit {
121+
storage?.cancel()
122+
}
123+
}
124+
125+
let context: _Context
112126

113127
/// Construct a AsyncThrowingStream buffering given an Element type.
114128
///
@@ -132,14 +146,26 @@ public struct AsyncThrowingStream<Element, Failure: Error> {
132146
_ build: (Continuation) -> Void
133147
) where Failure == Error {
134148
let storage: _Storage = .create(limit: limit)
135-
self.init(unfolding: storage.next)
149+
context = _Context(storage: storage, produce: storage.next)
136150
build(Continuation(storage: storage))
137151
}
138152

139153
public init(
140154
unfolding produce: @escaping () async throws -> Element?
141155
) where Failure == Error {
142-
self.produce = produce
156+
let storage: _AsyncStreamCriticalStorage<Optional<() async throws -> Element?>>
157+
= .create(produce)
158+
context = _Context {
159+
return try await withTaskCancellationHandler {
160+
guard let result = try await storage.value?() else {
161+
storage.value = nil
162+
return nil
163+
}
164+
return result
165+
} onCancel: {
166+
storage.value = nil
167+
}
168+
}
143169
}
144170
}
145171

@@ -152,16 +178,16 @@ extension AsyncThrowingStream: AsyncSequence {
152178
/// concurrently and contends with another call to next is a programmer error
153179
/// and will fatalError.
154180
public struct Iterator: AsyncIteratorProtocol {
155-
let produce: () async throws -> Element?
181+
let context: _Context
156182

157183
public mutating func next() async throws -> Element? {
158-
return try await produce()
184+
return try await context.produce()
159185
}
160186
}
161187

162188
/// Construct an iterator.
163189
public func makeAsyncIterator() -> Iterator {
164-
return Iterator(produce: produce)
190+
return Iterator(context: context)
165191
}
166192
}
167193

0 commit comments

Comments
 (0)