Skip to content

Commit bf57e39

Browse files
authored
🍒 [Concurrency] prevent races in task cancellation (#38425)
* [Concurrency] prevent races in task cancellation * Temporarily disable test async_task_cancellation_early.swift
1 parent 921e8b5 commit bf57e39

File tree

5 files changed

+113
-71
lines changed

5 files changed

+113
-71
lines changed

include/swift/ABI/TaskStatus.h

Lines changed: 33 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
#ifndef SWIFT_ABI_TASKSTATUS_H
2121
#define SWIFT_ABI_TASKSTATUS_H
2222

23-
#include "swift/ABI/Task.h"
2423
#include "swift/ABI/MetadataValues.h"
24+
#include "swift/ABI/Task.h"
2525

2626
namespace swift {
2727

@@ -30,7 +30,7 @@ namespace swift {
3030
/// TaskStatusRecords are typically allocated on the stack (possibly
3131
/// in the task context), partially initialized, and then atomically
3232
/// added to the task with `swift_task_addTaskStatusRecord`. While
33-
/// registered with the task, a status record should only be
33+
/// registered with the task, a status record should only be
3434
/// modified in ways that respect the possibility of asynchronous
3535
/// access by a cancelling thread. In particular, the chain of
3636
/// status records must not be disturbed. When the task leaves
@@ -51,13 +51,9 @@ class TaskStatusRecord {
5151
TaskStatusRecord(const TaskStatusRecord &) = delete;
5252
TaskStatusRecord &operator=(const TaskStatusRecord &) = delete;
5353

54-
TaskStatusRecordKind getKind() const {
55-
return Flags.getKind();
56-
}
54+
TaskStatusRecordKind getKind() const { return Flags.getKind(); }
5755

58-
TaskStatusRecord *getParent() const {
59-
return Parent;
60-
}
56+
TaskStatusRecord *getParent() const { return Parent; }
6157

6258
/// Change the parent of this unregistered status record to the
6359
/// given record.
@@ -77,9 +73,7 @@ class TaskStatusRecord {
7773
/// Unlike resetParent, this assumes that it's just removing one or
7874
/// more records from the chain and that there's no need to do any
7975
/// extra cache manipulation.
80-
void spliceParent(TaskStatusRecord *newParent) {
81-
Parent = newParent;
82-
}
76+
void spliceParent(TaskStatusRecord *newParent) { Parent = newParent; }
8377
};
8478

8579
/// A deadline for the task. If this is reached, the task will be
@@ -102,14 +96,12 @@ struct TaskDeadline {
10296
/// within the task.
10397
class DeadlineStatusRecord : public TaskStatusRecord {
10498
TaskDeadline Deadline;
99+
105100
public:
106101
DeadlineStatusRecord(TaskDeadline deadline)
107-
: TaskStatusRecord(TaskStatusRecordKind::Deadline),
108-
Deadline(deadline) {}
102+
: TaskStatusRecord(TaskStatusRecordKind::Deadline), Deadline(deadline) {}
109103

110-
TaskDeadline getDeadline() const {
111-
return Deadline;
112-
}
104+
TaskDeadline getDeadline() const { return Deadline; }
113105

114106
static bool classof(const TaskStatusRecord *record) {
115107
return record->getKind() == TaskStatusRecordKind::Deadline;
@@ -123,25 +115,22 @@ class ChildTaskStatusRecord : public TaskStatusRecord {
123115

124116
public:
125117
ChildTaskStatusRecord(AsyncTask *child)
126-
: TaskStatusRecord(TaskStatusRecordKind::ChildTask),
127-
FirstChild(child) {}
118+
: TaskStatusRecord(TaskStatusRecordKind::ChildTask), FirstChild(child) {}
128119

129120
ChildTaskStatusRecord(AsyncTask *child, TaskStatusRecordKind kind)
130-
: TaskStatusRecord(kind),
131-
FirstChild(child) {
121+
: TaskStatusRecord(kind), FirstChild(child) {
132122
assert(kind == TaskStatusRecordKind::ChildTask);
133123
assert(!child->hasGroupChildFragment() &&
134-
"Group child tasks must be tracked in their respective "
135-
"TaskGroupTaskStatusRecord, and not as independent ChildTaskStatusRecord "
136-
"records.");
124+
"Group child tasks must be tracked in their respective "
125+
"TaskGroupTaskStatusRecord, and not as independent "
126+
"ChildTaskStatusRecord "
127+
"records.");
137128
}
138129

139130
/// Return the first child linked by this record. This may be null;
140131
/// if not, it (and all of its successors) are guaranteed to satisfy
141132
/// `isChildTask()`.
142-
AsyncTask *getFirstChild() const {
143-
return FirstChild;
144-
}
133+
AsyncTask *getFirstChild() const { return FirstChild; }
145134

146135
static AsyncTask *getNextChildTask(AsyncTask *task) {
147136
return task->childFragment()->getNextChild();
@@ -175,25 +164,21 @@ class ChildTaskStatusRecord : public TaskStatusRecord {
175164
/// and are only tracked by their respective `TaskGroupTaskStatusRecord`.
176165
class TaskGroupTaskStatusRecord : public TaskStatusRecord {
177166
AsyncTask *FirstChild;
167+
178168
public:
179169
TaskGroupTaskStatusRecord()
180-
: TaskStatusRecord(TaskStatusRecordKind::TaskGroup),
181-
FirstChild(nullptr) {}
170+
: TaskStatusRecord(TaskStatusRecordKind::TaskGroup), FirstChild(nullptr) {
171+
}
182172

183173
TaskGroupTaskStatusRecord(AsyncTask *child)
184-
: TaskStatusRecord(TaskStatusRecordKind::TaskGroup),
185-
FirstChild(child) {}
174+
: TaskStatusRecord(TaskStatusRecordKind::TaskGroup), FirstChild(child) {}
186175

187-
TaskGroup* getGroup() {
188-
return reinterpret_cast<TaskGroup *>(this);
189-
}
176+
TaskGroup *getGroup() { return reinterpret_cast<TaskGroup *>(this); }
190177

191178
/// Return the first child linked by this record. This may be null;
192179
/// if not, it (and all of its successors) are guaranteed to satisfy
193180
/// `isChildTask()`.
194-
AsyncTask *getFirstChild() const {
195-
return FirstChild;
196-
}
181+
AsyncTask *getFirstChild() const { return FirstChild; }
197182

198183
/// Attach the passed in `child` task to this group.
199184
void attachChild(AsyncTask *child) {
@@ -207,7 +192,8 @@ class TaskGroupTaskStatusRecord : public TaskStatusRecord {
207192
return;
208193
}
209194

210-
// We need to traverse the siblings to find the last one and add the child there.
195+
// We need to traverse the siblings to find the last one and add the child
196+
// there.
211197
// FIXME: just set prepend to the current head, no need to traverse.
212198

213199
auto cur = FirstChild;
@@ -249,20 +235,18 @@ class TaskGroupTaskStatusRecord : public TaskStatusRecord {
249235
/// subsequently used.
250236
class CancellationNotificationStatusRecord : public TaskStatusRecord {
251237
public:
252-
using FunctionType = SWIFT_CC(swift) void (SWIFT_CONTEXT void *);
238+
using FunctionType = SWIFT_CC(swift) void(SWIFT_CONTEXT void *);
253239

254240
private:
255-
FunctionType * __ptrauth_swift_cancellation_notification_function Function;
241+
FunctionType *__ptrauth_swift_cancellation_notification_function Function;
256242
void *Argument;
257243

258244
public:
259245
CancellationNotificationStatusRecord(FunctionType *fn, void *arg)
260-
: TaskStatusRecord(TaskStatusRecordKind::CancellationNotification),
261-
Function(fn), Argument(arg) {}
246+
: TaskStatusRecord(TaskStatusRecordKind::CancellationNotification),
247+
Function(fn), Argument(arg) {}
262248

263-
void run() {
264-
Function(Argument);
265-
}
249+
void run() { Function(Argument); }
266250

267251
static bool classof(const TaskStatusRecord *record) {
268252
return record->getKind() == TaskStatusRecordKind::CancellationNotification;
@@ -279,20 +263,18 @@ class CancellationNotificationStatusRecord : public TaskStatusRecord {
279263
/// subsequently used.
280264
class EscalationNotificationStatusRecord : public TaskStatusRecord {
281265
public:
282-
using FunctionType = void (void *, JobPriority);
266+
using FunctionType = void(void *, JobPriority);
283267

284268
private:
285-
FunctionType * __ptrauth_swift_escalation_notification_function Function;
269+
FunctionType *__ptrauth_swift_escalation_notification_function Function;
286270
void *Argument;
287271

288272
public:
289273
EscalationNotificationStatusRecord(FunctionType *fn, void *arg)
290-
: TaskStatusRecord(TaskStatusRecordKind::EscalationNotification),
291-
Function(fn), Argument(arg) {}
274+
: TaskStatusRecord(TaskStatusRecordKind::EscalationNotification),
275+
Function(fn), Argument(arg) {}
292276

293-
void run(JobPriority newPriority) {
294-
Function(Argument, newPriority);
295-
}
277+
void run(JobPriority newPriority) { Function(Argument, newPriority); }
296278

297279
static bool classof(const TaskStatusRecord *record) {
298280
return record->getKind() == TaskStatusRecordKind::EscalationNotification;

stdlib/public/Concurrency/Task.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1029,7 +1029,12 @@ swift_task_addCancellationHandlerImpl(
10291029
auto *record = new (allocation)
10301030
CancellationNotificationStatusRecord(unsigned_handler, context);
10311031

1032-
swift_task_addStatusRecord(record);
1032+
if (swift_task_addStatusRecord(record))
1033+
return record;
1034+
1035+
// else, the task was already cancelled, so while the record was added,
1036+
// we must run it immediately here since no other task will trigger it.
1037+
record->run();
10331038
return record;
10341039
}
10351040

stdlib/public/Concurrency/TaskCancellation.swift

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,8 @@ public func withTaskCancellationHandler<T>(
3434
) async rethrows -> T {
3535
let task = Builtin.getCurrentAsyncTask()
3636

37-
guard !_taskIsCancelled(task) else {
38-
// If the current task is already cancelled, run the handler immediately.
39-
handler()
40-
return try await operation()
41-
}
42-
37+
// unconditionally add the cancellation record to the task.
38+
// if the task was already cancelled, it will be executed right away.
4339
let record = _taskAddCancellationHandler(handler: handler)
4440
defer { _taskRemoveCancellationHandler(record: record) }
4541

test/Concurrency/Runtime/async_task_cancellation_early.swift

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1-
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-concurrency %import-libdispatch -parse-as-library)
1+
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-concurrency %import-libdispatch -parse-as-library) | %FileCheck %s --dump-input=always
22

33
// REQUIRES: executable_test
44
// REQUIRES: concurrency
55
// REQUIRES: libdispatch
66

7+
// Temporarily disabled to unblock PR testing:
8+
// REQUIRES: rdar80745964
9+
710
// rdar://76038845
811
// UNSUPPORTED: use_os_stdlib
912
// UNSUPPORTED: back_deployment_runtime
@@ -21,14 +24,11 @@ func test_detach_cancel_child_early() async {
2124

2225
let xx = await childCancelled
2326
print("child, cancelled: \(xx)") // CHECK: child, cancelled: true
24-
let cancelled = Task.isCancelled
25-
print("self, cancelled: \(cancelled )") // CHECK: self, cancelled: true
27+
let cancelled = Task.isCancelled
28+
print("self, cancelled: \(cancelled)") // CHECK: self, cancelled: true
2629
return cancelled
2730
}
2831

29-
// no sleep here -- this confirms that the child task `x`
30-
// carries the cancelled flag, as it is started from a cancelled task.
31-
3232
h.cancel()
3333
print("handle cancel")
3434
let got = try! await h.value

test/Concurrency/Runtime/async_task_cancellation_while_running.swift

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-concurrency %import-libdispatch -parse-as-library)
1+
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-concurrency %import-libdispatch -parse-as-library) | %FileCheck %s
22

33
// REQUIRES: executable_test
44
// REQUIRES: concurrency
@@ -10,11 +10,13 @@
1010

1111
import Dispatch
1212

13+
let seconds: UInt64 = 1_000_000_000
14+
1315
@available(SwiftStdlib 5.5, *)
1416
func test_detach_cancel_while_child_running() async {
15-
let h: Task<Bool, Error> = detach {
17+
let task: Task<Bool, Error> = Task.detached {
1618
async let childCancelled: Bool = { () -> Bool in
17-
await Task.sleep(3_000_000_000)
19+
await Task.sleep(3 * seconds)
1820
return Task.isCancelled
1921
}()
2022

@@ -26,17 +28,74 @@ func test_detach_cancel_while_child_running() async {
2628
}
2729

2830
// sleep here, i.e. give the task a moment to start running
29-
await Task.sleep(2_000_000_000)
31+
await Task.sleep(2 * seconds)
32+
33+
task.cancel()
34+
print("task.cancel()")
35+
let got = try! await task.get()
36+
print("was cancelled: \(got)") // CHECK: was cancelled: true
37+
}
38+
39+
@available(SwiftStdlib 5.5, *)
40+
func test_cancel_while_withTaskCancellationHandler_inflight() async {
41+
let task: Task<Bool, Error> = Task.detached {
42+
await withTaskCancellationHandler {
43+
await Task.sleep(2 * seconds)
44+
print("operation-1")
45+
await Task.sleep(1 * seconds)
46+
print("operation-2")
47+
return Task.isCancelled
48+
} onCancel: {
49+
print("onCancel")
50+
}
51+
52+
}
53+
54+
await Task.sleep(1 * seconds)
55+
56+
// CHECK: task.cancel()
57+
// CHECK: onCancel
58+
// CHECK: operation-1
59+
// CHECK: operation-2
60+
print("task.cancel()")
61+
task.cancel()
62+
let got = try! await task.get()
63+
print("was cancelled: \(got)") // CHECK: was cancelled: true
64+
}
65+
66+
@available(SwiftStdlib 5.5, *)
67+
func test_cancel_while_withTaskCancellationHandler_onlyOnce() async {
68+
let task: Task<Bool, Error> = Task.detached {
69+
await withTaskCancellationHandler {
70+
await Task.sleep(2 * seconds)
71+
await Task.sleep(2 * seconds)
72+
await Task.sleep(2 * seconds)
73+
print("operation-done")
74+
return Task.isCancelled
75+
} onCancel: {
76+
print("onCancel")
77+
}
78+
}
79+
80+
await Task.sleep(1 * seconds)
3081

31-
h.cancel()
32-
print("handle cancel")
33-
let got = try! await h.get()
82+
// CHECK: task.cancel()
83+
// CHECK: onCancel
84+
// onCancel runs only once, even though we attempt to cancel the task many times
85+
// CHECK-NEXT: operation-done
86+
print("task.cancel()")
87+
task.cancel()
88+
task.cancel()
89+
task.cancel()
90+
let got = try! await task.get()
3491
print("was cancelled: \(got)") // CHECK: was cancelled: true
3592
}
3693

3794
@available(SwiftStdlib 5.5, *)
3895
@main struct Main {
3996
static func main() async {
4097
await test_detach_cancel_while_child_running()
98+
await test_cancel_while_withTaskCancellationHandler_inflight()
99+
await test_cancel_while_withTaskCancellationHandler_onlyOnce()
41100
}
42101
}

0 commit comments

Comments
 (0)