Skip to content

Commit 4bd1099

Browse files
committed
prepare for cancellation handling
1 parent 268dec1 commit 4bd1099

File tree

7 files changed

+46
-30
lines changed

7 files changed

+46
-30
lines changed

include/swift/Runtime/Concurrency.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,8 @@ bool swift_taskGroup_isDiscardingResults(TaskGroup *group);
307307
/// \code
308308
/// func swift_taskGroup_waitAll(
309309
/// waitingTask: Builtin.NativeObject, // current task
310-
/// group: Builtin.RawPointer
310+
/// group: Builtin.RawPointer,
311+
/// childFailureCancelsGroup: Bool
311312
/// ) async throws
312313
/// \endcode
313314
SWIFT_EXPORT_FROM(swift_Concurrency)
@@ -316,6 +317,7 @@ bool swift_taskGroup_isDiscardingResults(TaskGroup *group);
316317
OpaqueValue *resultPointer,
317318
SWIFT_ASYNC_CONTEXT AsyncContext *callerContext,
318319
TaskGroup *group,
320+
bool childFailureCancelsGroup,
319321
ThrowingTaskFutureWaitContinuationFunction *resumeFn,
320322
AsyncContext *callContext);
321323

stdlib/public/CompatibilityOverride/CompatibilityOverrideConcurrency.def

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,10 +326,11 @@ OVERRIDE_TASK_GROUP(taskGroup_waitAll, void,
326326
(OpaqueValue *resultPointer,
327327
SWIFT_ASYNC_CONTEXT AsyncContext *callerContext,
328328
TaskGroup *_group,
329+
bool childFailureCancelsGroup,
329330
ThrowingTaskFutureWaitContinuationFunction *resumeFn,
330331
AsyncContext *callContext),
331-
(resultPointer, callerContext, _group, resumeFn,
332-
callContext))
332+
(resultPointer, callerContext, _group, childFailureCancelsGroup,
333+
resumeFn, callContext))
333334

334335
OVERRIDE_TASK_LOCAL(task_reportIllegalTaskLocalBindingWithinWithTaskGroup, void,
335336
SWIFT_EXPORT_FROM(swift_Concurrency), SWIFT_CC(swift), swift::,

stdlib/public/Concurrency/TaskGroup.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ using namespace swift;
5454
/*************************** TASK GROUP ***************************************/
5555
/******************************************************************************/
5656

57-
#if 1
57+
#if 0
5858
#define SWIFT_TASK_GROUP_DEBUG_LOG(group, fmt, ...) \
5959
fprintf(stderr, "[%#lx] [%s:%d](%s) group(%p%s) " fmt "\n", \
6060
(unsigned long)Thread::current().platformThreadId(), \
@@ -956,6 +956,7 @@ __attribute__((noinline))
956956
SWIFT_CC(swiftasync) static void workaround_function_swift_taskGroup_waitAllImpl(
957957
OpaqueValue *result, SWIFT_ASYNC_CONTEXT AsyncContext *callerContext,
958958
TaskGroup *_group,
959+
bool childFailureCancelsGroup,
959960
ThrowingTaskFutureWaitContinuationFunction resumeFunction,
960961
AsyncContext *callContext) {
961962
// Make sure we don't eliminate calls to this function.
@@ -1187,6 +1188,7 @@ SWIFT_CC(swiftasync)
11871188
static void swift_taskGroup_waitAllImpl(
11881189
OpaqueValue *resultPointer, SWIFT_ASYNC_CONTEXT AsyncContext *callerContext,
11891190
TaskGroup *_group,
1191+
bool childFailureCancelsGroup,
11901192
ThrowingTaskFutureWaitContinuationFunction *resumeFunction,
11911193
AsyncContext *rawContext) {
11921194
auto waitingTask = swift_task_getCurrent();
@@ -1213,7 +1215,7 @@ static void swift_taskGroup_waitAllImpl(
12131215
// there were pending tasks so it will be woken up eventually.
12141216
#ifdef __ARM_ARCH_7K__
12151217
return workaround_function_swift_taskGroup_waitAllImpl(
1216-
resultPointer, callerContext, _group, resumeFunction, rawContext);
1218+
resultPointer, callerContext, _group, childFailureCancelsGroup, resumeFunction, rawContext);
12171219
#else /* __ARM_ARCH_7K__ */
12181220
return;
12191221
#endif /* __ARM_ARCH_7K__ */

stdlib/public/Concurrency/TaskGroup.swift

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ public func withTaskGroup<ChildTaskResult, GroupResult>(
110110
// Run the withTaskGroup body.
111111
let result = await body(&group)
112112

113-
let _: ChildTaskResult? = try? await _taskGroupWaitAll(group: _group) // try!-safe, cannot throw since this is a non throwing group
113+
let _: ChildTaskResult? = try? await _taskGroupWaitAll(group: _group, childFailureCancelsGroup: discardResults) // try!-safe, cannot throw since this is a non throwing group
114114
return result
115115
}
116116

@@ -198,14 +198,14 @@ public func withThrowingTaskGroup<ChildTaskResult, GroupResult>(
198198
// Run the withTaskGroup body.
199199
let result = try await body(&group)
200200

201-
_ = try? await group.awaitAllRemainingTasks()
201+
await group.awaitAllRemainingTasks()
202202
Builtin.destroyTaskGroup(_group)
203203

204204
return result
205205
} catch {
206206
group.cancelAll()
207207

208-
_ = try? await group.awaitAllRemainingTasks() // discard errors
208+
await group.awaitAllRemainingTasks()
209209
Builtin.destroyTaskGroup(_group)
210210

211211
throw error
@@ -231,28 +231,24 @@ public func withThrowingTaskGroup<ChildTaskResult, GroupResult>(
231231

232232
let _group = Builtin.createTaskGroupWithFlags(flags, ChildTaskResult.self)
233233
var group = ThrowingTaskGroup<ChildTaskResult, Error>(group: _group)
234+
defer { Builtin.destroyTaskGroup(_group) }
234235

236+
let result: GroupResult
235237
do {
236238
// Run the withTaskGroup body.
237-
let result = try await body(&group)
238-
239-
try await group.awaitAllRemainingTasks()
240-
Builtin.destroyTaskGroup(_group)
241-
242-
return result
239+
result = try await body(&group)
243240
} catch {
244241
group.cancelAll()
245242

246-
do {
247-
try await group.awaitAllRemainingTasks()
248-
Builtin.destroyTaskGroup(_group)
249-
} catch {
250-
Builtin.destroyTaskGroup(_group)
251-
throw error
252-
}
243+
await group.awaitAllRemainingTasks()
253244

254245
throw error
255246
}
247+
248+
// FIXME: if one of them throws, cancel the group
249+
try await group.awaitAllRemainingTasksThrowing(childFailureCancelsGroup: true)
250+
251+
return result
256252
}
257253

258254
/// A group that contains dynamically created child tasks.
@@ -514,7 +510,7 @@ public struct TaskGroup<ChildTaskResult: Sendable> {
514510
/// implementation.
515511
if #available(SwiftStdlib 5.8, *) {
516512
if isDiscardingResults {
517-
let _: ChildTaskResult? = try! await _taskGroupWaitAll(group: _group) // try!-safe, cannot throw, not throwing group
513+
let _: ChildTaskResult? = try! await _taskGroupWaitAll(group: _group, childFailureCancelsGroup: isDiscardingResults) // try!-safe, cannot throw, not throwing group
518514
return
519515
}
520516
}
@@ -641,13 +637,22 @@ public struct ThrowingTaskGroup<ChildTaskResult: Sendable, Failure: Error> {
641637

642638
/// Await all the remaining tasks on this group.
643639
@usableFromInline
644-
internal mutating func awaitAllRemainingTasks() async throws {
640+
@available(*, deprecated, message: "Use `awaitAllRemainingTasksThrowing`, since 5.8 with discardResults draining may throw")
641+
internal mutating func awaitAllRemainingTasks() async {
642+
// We discard the error because in old code, which may have inlined this `awaitAllRemainingTasks`
643+
// method, draining was never going to throw
644+
_ = try? await awaitAllRemainingTasksThrowing(childFailureCancelsGroup: false)
645+
}
646+
647+
/// Await all the remaining tasks on this group.
648+
@usableFromInline
649+
internal mutating func awaitAllRemainingTasksThrowing(childFailureCancelsGroup: Bool) async throws {
645650
/// Since 5.8, we implement "wait for all pending tasks to complete"
646651
/// in the runtime, in order to be able to handle the discard-results
647652
/// implementation.
648653
if #available(SwiftStdlib 5.8, *) {
649654
if isDiscardingResults {
650-
let _: ChildTaskResult? = try await _taskGroupWaitAll(group: _group) // if any of the tasks throws, this will "rethrow" here
655+
let _: ChildTaskResult? = try await _taskGroupWaitAll(group: _group, childFailureCancelsGroup: childFailureCancelsGroup) // if any of the tasks throws, this will "rethrow" here
651656
return
652657
}
653658
}
@@ -674,7 +679,7 @@ public struct ThrowingTaskGroup<ChildTaskResult: Sendable, Failure: Error> {
674679
/// - Throws: only during
675680
@_alwaysEmitIntoClient
676681
public mutating func waitForAll() async throws {
677-
try await self.awaitAllRemainingTasks()
682+
try await self.awaitAllRemainingTasksThrowing(childFailureCancelsGroup: false)
678683
}
679684

680685
#if !SWIFT_STDLIB_TASK_TO_THREAD_MODEL_CONCURRENCY
@@ -1207,7 +1212,7 @@ func _taskHasTaskGroupStatusRecord() -> Bool
12071212
@usableFromInline
12081213
@discardableResult
12091214
@_silgen_name("swift_taskGroup_waitAll")
1210-
func _taskGroupWaitAll<T>(group: Builtin.RawPointer) async throws -> T?
1215+
func _taskGroupWaitAll<T>(group: Builtin.RawPointer, childFailureCancelsGroup: Bool) async throws -> T?
12111216

12121217
@available(SwiftStdlib 5.8, *)
12131218
@_silgen_name("swift_taskGroup_isDiscardingResults")

stdlib/public/Concurrency/TaskPrivate.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,10 @@ namespace {
137137
/// func _taskGroupWaitNext<T>(group: Builtin.RawPointer) async throws -> T?
138138
///
139139
/// @_silgen_name("swift_taskGroup_waitAll")
140-
/// func _taskGroupWaitAll<T>(group: Builtin.RawPointer) async throws -> T?
140+
/// func _taskGroupWaitAll<T>(
141+
/// group: Builtin.RawPointer,
142+
/// childFailureCancelsGroup: Bool
143+
/// ) async throws -> T?
141144
///
142145
class TaskFutureWaitAsyncContext : public AsyncContext {
143146
public:

test/Concurrency/Runtime/async_taskgroup_throw_rethrow.swift

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,17 @@ func test_taskGroup_discardResults_automaticallyRethrowsOnlyFirst() async {
9393
let got = try await withThrowingTaskGroup(of: Int.self, returning: Int.self,
9494
discardResults: true) { group in
9595
group.addTask { await echo(1) }
96-
group.addTask { throw Boom(id: "first") }
96+
group.addTask { throw Boom(id: "first-a") }
97+
group.addTask { throw Boom(id: "first-b") }
9798
// add a throwing task, but don't consume it explicitly
9899
// since we're in discard results mode, all will be awaited and the first error it thrown
99100

100101
do {
101102
try await group.waitForAll()
102103
} catch {
103-
// CHECK: caught: Boom(id: "first")
104+
// There's no guarantee about which of the `first-...` tasks will complete first,
105+
// however, they all will be consumed when we have returned from the `waitForAll`.
106+
// CHECK: caught: Boom(id: "first
104107
print("caught: \(error)")
105108
}
106109

unittests/runtime/CompatibilityOverrideConcurrency.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ TEST_F(CompatibilityOverrideConcurrencyTest, test_swift_taskGroup_cancelAll) {
216216
}
217217

218218
TEST_F(CompatibilityOverrideConcurrencyTest, test_swift_taskGroup_waitAll) {
219-
swift_taskGroup_waitAll(nullptr, nullptr, nullptr, nullptr,
219+
swift_taskGroup_waitAll(nullptr, nullptr, nullptr, false, nullptr,
220220
nullptr);
221221
}
222222

0 commit comments

Comments
 (0)