Skip to content

Fix task cancellation leak #1418

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Sep 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 16 additions & 21 deletions Sources/ComposableArchitecture/Effects/Cancellation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -196,27 +196,24 @@ public func withTaskCancellation<T: Sendable>(
cancelInFlight: Bool = false,
operation: @Sendable @escaping () async throws -> T
) async rethrows -> T {
let task = { () -> Task<T, Error> in
cancellablesLock.lock()
let id = CancelToken(id: id)
let id = CancelToken(id: id)
let (cancellable, task) = cancellablesLock.sync { () -> (AnyCancellable, Task<T, Error>) in
if cancelInFlight {
cancellationCancellables[id]?.forEach { $0.cancel() }
}
let task = Task { try await operation() }
var cancellable: AnyCancellable!
cancellable = AnyCancellable {
task.cancel()
cancellablesLock.sync {
cancellationCancellables[id]?.remove(cancellable)
if cancellationCancellables[id]?.isEmpty == .some(true) {
cancellationCancellables[id] = nil
}
let cancellable = AnyCancellable { task.cancel() }
cancellationCancellables[id, default: []].insert(cancellable)
return (cancellable, task)
}
defer {
cancellablesLock.sync {
cancellationCancellables[id]?.remove(cancellable)
if cancellationCancellables[id]?.isEmpty == .some(true) {
cancellationCancellables[id] = nil
}
}
cancellationCancellables[id, default: []].insert(cancellable)
cancellablesLock.unlock()
return task
}()
}
do {
return try await task.cancellableValue
} catch {
Expand Down Expand Up @@ -252,10 +249,8 @@ extension Task where Success == Never, Failure == Never {
/// Cancel any currently in-flight operation with the given identifier.
///
/// - Parameter id: An identifier.
public static func cancel<ID: Hashable & Sendable>(id: ID) async {
await MainActor.run {
cancellablesLock.sync { cancellationCancellables[.init(id: id)]?.forEach { $0.cancel() } }
}
public static func cancel<ID: Hashable & Sendable>(id: ID) {
cancellablesLock.sync { cancellationCancellables[.init(id: id)]?.forEach { $0.cancel() } }
}

/// Cancel any currently in-flight operation with the given identifier.
Expand All @@ -264,8 +259,8 @@ extension Task where Success == Never, Failure == Never {
/// identifier.
///
/// - Parameter id: A unique type identifying the operation.
public static func cancel(id: Any.Type) async {
await self.cancel(id: ObjectIdentifier(id))
public static func cancel(id: Any.Type) {
self.cancel(id: ObjectIdentifier(id))
}
}

Expand Down
4 changes: 2 additions & 2 deletions Tests/ComposableArchitectureTests/EffectRunTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ final class EffectRunTests: XCTestCase {
switch action {
case .tapped:
return .run { send in
await Task.cancel(id: CancelID.self)
Task.cancel(id: CancelID.self)
try Task.checkCancellation()
await send(.response)
}
Expand All @@ -101,7 +101,7 @@ final class EffectRunTests: XCTestCase {
switch action {
case .tapped:
return .run { send in
await Task.cancel(id: CancelID.self)
Task.cancel(id: CancelID.self)
try Task.checkCancellation()
await send(.responseA)
} catch: { @Sendable _, send in // NB: Explicit '@Sendable' required in 5.5.2
Expand Down
4 changes: 2 additions & 2 deletions Tests/ComposableArchitectureTests/EffectTaskTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ final class EffectTaskTests: XCTestCase {
switch action {
case .tapped:
return .task {
await Task.cancel(id: CancelID.self)
Task.cancel(id: CancelID.self)
try Task.checkCancellation()
return .response
}
Expand All @@ -101,7 +101,7 @@ final class EffectTaskTests: XCTestCase {
switch action {
case .tapped:
return .task {
await Task.cancel(id: CancelID.self)
Task.cancel(id: CancelID.self)
try Task.checkCancellation()
return .responseA
} catch: { @Sendable _ in // NB: Explicit '@Sendable' required in 5.5.2
Expand Down
24 changes: 21 additions & 3 deletions Tests/ComposableArchitectureTests/TaskCancellationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import XCTest

final class TaskCancellationTests: XCTestCase {
func testCancellation() async throws {
cancellationCancellables.removeAll()
cancellablesLock.sync {
cancellationCancellables.removeAll()
}
enum ID {}
let (stream, continuation) = AsyncStream<Void>.streamWithContinuation()
let task = Task {
Expand All @@ -16,12 +18,28 @@ final class TaskCancellationTests: XCTestCase {
}
}
await stream.first(where: { true })
await Task.cancel(id: ID.self)
XCTAssertEqual(cancellationCancellables, [:])
Task.cancel(id: ID.self)
await Task.megaYield(count: 20)
XCTAssertEqual(cancellablesLock.sync { cancellationCancellables }, [:])
do {
try await task.cancellableValue
XCTFail()
} catch {
}
}

func testWithTaskCancellationCleansUpTask() async throws {
let task = Task {
try await withTaskCancellation(id: 0) {
try await Task.sleep(nanoseconds: NSEC_PER_SEC * 1000)
}
}

try await Task.sleep(nanoseconds: NSEC_PER_SEC / 3)
XCTAssertEqual(cancellationCancellables.count, 1)

task.cancel()
try await Task.sleep(nanoseconds: NSEC_PER_SEC / 3)
XCTAssertEqual(cancellationCancellables.count, 0)
}
}