Skip to content

Commit 3b4c14d

Browse files
committed
Make sure that cancelling a task group does not invoke cancellation
handler of parent task that created the group Change comment in TaskGroup.swift to enforce that only parent task can call cancelAll on the group Add tests to verify mutating of task group in child tasks will fail Radar-Id: rdar://problem/86346865
1 parent 469099f commit 3b4c14d

File tree

7 files changed

+86
-103
lines changed

7 files changed

+86
-103
lines changed

include/swift/ABI/TaskGroup.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#define SWIFT_ABI_TASK_GROUP_H
1919

2020
#include "swift/ABI/Task.h"
21+
#include "swift/ABI/TaskStatus.h"
2122
#include "swift/ABI/HeapObject.h"
2223
#include "swift/Runtime/Concurrency.h"
2324
#include "swift/Runtime/Config.h"
@@ -46,6 +47,9 @@ class alignas(Alignment_TaskGroup) TaskGroup {
4647
// Add a child task to the group. Always called with the status record lock of
4748
// the parent task held
4849
void addChildTask(AsyncTask *task);
50+
51+
// Provide accessor for task group's status record
52+
TaskGroupTaskStatusRecord *getTaskRecord();
4953
};
5054

5155
} // end namespace swift

stdlib/public/Concurrency/TaskGroup.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,10 @@ static TaskGroup *asAbstract(TaskGroupImpl *group) {
449449
return reinterpret_cast<TaskGroup*>(group);
450450
}
451451

452+
TaskGroupTaskStatusRecord * TaskGroup::getTaskRecord() {
453+
return asImpl(this)->getTaskRecord();
454+
}
455+
452456
// =============================================================================
453457
// ==== initialize -------------------------------------------------------------
454458

stdlib/public/Concurrency/TaskGroup.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,8 +373,8 @@ public struct TaskGroup<ChildTaskResult: Sendable> {
373373
/// If you add a task to a group after canceling the group,
374374
/// that task is canceled immediately after being added to the group.
375375
///
376-
/// There are no restrictions on where you can call this method.
377-
/// Code inside a child task or even another task can cancel a group.
376+
/// This method can only be called by the parent task that created the task
377+
/// group.
378378
///
379379
/// - SeeAlso: `Task.isCancelled`
380380
/// - SeeAlso: `TaskGroup.isCancelled`

stdlib/public/Concurrency/TaskStatus.cpp

Lines changed: 9 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -535,41 +535,6 @@ static void performCancellationAction(TaskStatusRecord *record) {
535535
// FIXME: allow dynamic extension/correction?
536536
}
537537

538-
/// Perform any cancellation actions required by the given record.
539-
static void performGroupCancellationAction(TaskStatusRecord *record) {
540-
switch (record->getKind()) {
541-
// We only need to cancel specific GroupChildTasks, not arbitrary child tasks.
542-
// A task may be parent to many tasks which are not part of a group after all.
543-
case TaskStatusRecordKind::ChildTask:
544-
return;
545-
546-
case TaskStatusRecordKind::TaskGroup: {
547-
auto groupChildRecord = cast<TaskGroupTaskStatusRecord>(record);
548-
// Since a task can only be running a single task group at the same time,
549-
// we can always assume that the group record which we found is the one
550-
// we're intended to cancel child tasks for.
551-
//
552-
// A group enforces that tasks can not "escape" it, and as such once the group
553-
// returns, all its task have been completed.
554-
for (AsyncTask *child: groupChildRecord->children()) {
555-
swift_task_cancel(child);
556-
}
557-
return;
558-
}
559-
560-
// All other kinds of records we handle the same way as in a normal cancellation
561-
case TaskStatusRecordKind::Deadline:
562-
case TaskStatusRecordKind::CancellationNotification:
563-
case TaskStatusRecordKind::EscalationNotification:
564-
case TaskStatusRecordKind::Private_RecordLock:
565-
performCancellationAction(record);
566-
return;
567-
}
568-
569-
// Other cases can fall through here and be ignored.
570-
// FIXME: allow dynamic extension/correction?
571-
}
572-
573538
SWIFT_CC(swift)
574539
static void swift_task_cancelImpl(AsyncTask *task) {
575540
SWIFT_TASK_DEBUG_LOG("cancel task = %p", task);
@@ -608,20 +573,16 @@ static void swift_task_cancel_group_child_tasksImpl(TaskGroup *group) {
608573

609574
// Acquire the status record lock.
610575
//
611-
// We purposefully DO NOT make this a cancellation by itself.
612-
// We are cancelling the task group, and all tasks it contains.
613-
// We are NOT cancelling the entire parent task though.
576+
// Guaranteed to be called from the context of the parent task that created
577+
// the task group once we have #40616
614578
auto task = swift_task_getCurrent();
615-
auto oldStatus = acquireStatusRecordLock(task, recordLockRecord,
616-
LockContext::OnTask);
617-
// Carry out the cancellation operations associated with all
618-
// the active records.
619-
for (auto cur: oldStatus.records()) {
620-
performGroupCancellationAction(cur);
621-
}
622-
623-
// Release the status record lock, restoring exactly the old status.
624-
releaseStatusRecordLock(task, oldStatus, recordLockRecord);
579+
withStatusRecordLock(task, LockContext::OnTask,
580+
[&](ActiveTaskStatus &status) {
581+
// We purposefully DO NOT make this a cancellation by itself.
582+
// We are cancelling the task group, and all tasks it contains.
583+
// We are NOT cancelling the entire parent task though.
584+
performCancellationAction(group->getTaskRecord());
585+
});
625586
}
626587

627588
/**************************************************************************/

test/Concurrency/Runtime/async_taskgroup_cancel_from_inside_child.swift

Lines changed: 0 additions & 53 deletions
This file was deleted.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: %target-run-simple-swift( -Xfrontend -disable-availability-checking %import-libdispatch -parse-as-library) | %FileCheck %s
2+
3+
// REQUIRES: executable_test
4+
// REQUIRES: concurrency
5+
// REQUIRES: libdispatch
6+
7+
// rdar://76038845
8+
// REQUIRES: concurrency_runtime
9+
// UNSUPPORTED: back_deployment_runtime
10+
11+
func test_taskGroup_cancelAll() async {
12+
13+
await withTaskCancellationHandler {
14+
await withTaskGroup(of: Int.self, returning: Void.self) { group in
15+
group.spawn {
16+
await Task.sleep(3_000_000_000)
17+
let c = Task.isCancelled
18+
print("group task isCancelled: \(c)")
19+
return 0
20+
}
21+
22+
group.cancelAll() // Cancels the group but not the task
23+
_ = await group.next()
24+
}
25+
} onCancel : {
26+
print("parent task cancel handler called")
27+
}
28+
29+
// CHECK-NOT: parent task cancel handler called
30+
// CHECK: group task isCancelled: true
31+
// CHECK: done
32+
print("done")
33+
}
34+
35+
@available(SwiftStdlib 5.1, *)
36+
@main struct Main {
37+
static func main() async {
38+
await test_taskGroup_cancelAll()
39+
}
40+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: %target-typecheck-verify-swift -disable-availability-checking
2+
// REQUIRES: concurrency
3+
4+
@available(SwiftStdlib 5.1, *)
5+
func test_taskGroup_cancelAll() async {
6+
await withTaskGroup(of: Int.self, returning: Void.self) { group in
7+
group.spawn {
8+
await Task.sleep(3_000_000_000)
9+
let c = Task.isCancelled
10+
print("group task isCancelled: \(c)")
11+
return 0
12+
}
13+
14+
group.spawn {
15+
group.cancelAll() //expected-warning{{capture of 'group' with non-sendable type 'TaskGroup<Int>' in a `@Sendable` closure}}
16+
//expected-error@-1{{reference to captured parameter 'group' in concurrently-executing code}}
17+
return 0
18+
}
19+
group.spawn { [group] in
20+
group.cancelAll() //expected-warning{{capture of 'group' with non-sendable type 'TaskGroup<Int>' in a `@Sendable` closure}}
21+
return 0
22+
}
23+
_ = await group.next()
24+
}
25+
26+
print("done")
27+
}

0 commit comments

Comments
 (0)