Skip to content

🍒 [Concurrency] prevent races in task cancellation #38425

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 33 additions & 51 deletions include/swift/ABI/TaskStatus.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
#ifndef SWIFT_ABI_TASKSTATUS_H
#define SWIFT_ABI_TASKSTATUS_H

#include "swift/ABI/Task.h"
#include "swift/ABI/MetadataValues.h"
#include "swift/ABI/Task.h"

namespace swift {

Expand All @@ -30,7 +30,7 @@ namespace swift {
/// TaskStatusRecords are typically allocated on the stack (possibly
/// in the task context), partially initialized, and then atomically
/// added to the task with `swift_task_addTaskStatusRecord`. While
/// registered with the task, a status record should only be
/// registered with the task, a status record should only be
/// modified in ways that respect the possibility of asynchronous
/// access by a cancelling thread. In particular, the chain of
/// status records must not be disturbed. When the task leaves
Expand All @@ -51,13 +51,9 @@ class TaskStatusRecord {
TaskStatusRecord(const TaskStatusRecord &) = delete;
TaskStatusRecord &operator=(const TaskStatusRecord &) = delete;

TaskStatusRecordKind getKind() const {
return Flags.getKind();
}
TaskStatusRecordKind getKind() const { return Flags.getKind(); }

TaskStatusRecord *getParent() const {
return Parent;
}
TaskStatusRecord *getParent() const { return Parent; }

/// Change the parent of this unregistered status record to the
/// given record.
Expand All @@ -77,9 +73,7 @@ class TaskStatusRecord {
/// Unlike resetParent, this assumes that it's just removing one or
/// more records from the chain and that there's no need to do any
/// extra cache manipulation.
void spliceParent(TaskStatusRecord *newParent) {
Parent = newParent;
}
void spliceParent(TaskStatusRecord *newParent) { Parent = newParent; }
};

/// A deadline for the task. If this is reached, the task will be
Expand All @@ -102,14 +96,12 @@ struct TaskDeadline {
/// within the task.
class DeadlineStatusRecord : public TaskStatusRecord {
TaskDeadline Deadline;

public:
DeadlineStatusRecord(TaskDeadline deadline)
: TaskStatusRecord(TaskStatusRecordKind::Deadline),
Deadline(deadline) {}
: TaskStatusRecord(TaskStatusRecordKind::Deadline), Deadline(deadline) {}

TaskDeadline getDeadline() const {
return Deadline;
}
TaskDeadline getDeadline() const { return Deadline; }

static bool classof(const TaskStatusRecord *record) {
return record->getKind() == TaskStatusRecordKind::Deadline;
Expand All @@ -123,25 +115,22 @@ class ChildTaskStatusRecord : public TaskStatusRecord {

public:
ChildTaskStatusRecord(AsyncTask *child)
: TaskStatusRecord(TaskStatusRecordKind::ChildTask),
FirstChild(child) {}
: TaskStatusRecord(TaskStatusRecordKind::ChildTask), FirstChild(child) {}

ChildTaskStatusRecord(AsyncTask *child, TaskStatusRecordKind kind)
: TaskStatusRecord(kind),
FirstChild(child) {
: TaskStatusRecord(kind), FirstChild(child) {
assert(kind == TaskStatusRecordKind::ChildTask);
assert(!child->hasGroupChildFragment() &&
"Group child tasks must be tracked in their respective "
"TaskGroupTaskStatusRecord, and not as independent ChildTaskStatusRecord "
"records.");
"Group child tasks must be tracked in their respective "
"TaskGroupTaskStatusRecord, and not as independent "
"ChildTaskStatusRecord "
"records.");
}

/// Return the first child linked by this record. This may be null;
/// if not, it (and all of its successors) are guaranteed to satisfy
/// `isChildTask()`.
AsyncTask *getFirstChild() const {
return FirstChild;
}
AsyncTask *getFirstChild() const { return FirstChild; }

static AsyncTask *getNextChildTask(AsyncTask *task) {
return task->childFragment()->getNextChild();
Expand Down Expand Up @@ -175,25 +164,21 @@ class ChildTaskStatusRecord : public TaskStatusRecord {
/// and are only tracked by their respective `TaskGroupTaskStatusRecord`.
class TaskGroupTaskStatusRecord : public TaskStatusRecord {
AsyncTask *FirstChild;

public:
TaskGroupTaskStatusRecord()
: TaskStatusRecord(TaskStatusRecordKind::TaskGroup),
FirstChild(nullptr) {}
: TaskStatusRecord(TaskStatusRecordKind::TaskGroup), FirstChild(nullptr) {
}

TaskGroupTaskStatusRecord(AsyncTask *child)
: TaskStatusRecord(TaskStatusRecordKind::TaskGroup),
FirstChild(child) {}
: TaskStatusRecord(TaskStatusRecordKind::TaskGroup), FirstChild(child) {}

TaskGroup* getGroup() {
return reinterpret_cast<TaskGroup *>(this);
}
TaskGroup *getGroup() { return reinterpret_cast<TaskGroup *>(this); }

/// Return the first child linked by this record. This may be null;
/// if not, it (and all of its successors) are guaranteed to satisfy
/// `isChildTask()`.
AsyncTask *getFirstChild() const {
return FirstChild;
}
AsyncTask *getFirstChild() const { return FirstChild; }

/// Attach the passed in `child` task to this group.
void attachChild(AsyncTask *child) {
Expand All @@ -207,7 +192,8 @@ class TaskGroupTaskStatusRecord : public TaskStatusRecord {
return;
}

// We need to traverse the siblings to find the last one and add the child there.
// We need to traverse the siblings to find the last one and add the child
// there.
// FIXME: just set prepend to the current head, no need to traverse.

auto cur = FirstChild;
Expand Down Expand Up @@ -249,20 +235,18 @@ class TaskGroupTaskStatusRecord : public TaskStatusRecord {
/// subsequently used.
class CancellationNotificationStatusRecord : public TaskStatusRecord {
public:
using FunctionType = SWIFT_CC(swift) void (SWIFT_CONTEXT void *);
using FunctionType = SWIFT_CC(swift) void(SWIFT_CONTEXT void *);

private:
FunctionType * __ptrauth_swift_cancellation_notification_function Function;
FunctionType *__ptrauth_swift_cancellation_notification_function Function;
void *Argument;

public:
CancellationNotificationStatusRecord(FunctionType *fn, void *arg)
: TaskStatusRecord(TaskStatusRecordKind::CancellationNotification),
Function(fn), Argument(arg) {}
: TaskStatusRecord(TaskStatusRecordKind::CancellationNotification),
Function(fn), Argument(arg) {}

void run() {
Function(Argument);
}
void run() { Function(Argument); }

static bool classof(const TaskStatusRecord *record) {
return record->getKind() == TaskStatusRecordKind::CancellationNotification;
Expand All @@ -279,20 +263,18 @@ class CancellationNotificationStatusRecord : public TaskStatusRecord {
/// subsequently used.
class EscalationNotificationStatusRecord : public TaskStatusRecord {
public:
using FunctionType = void (void *, JobPriority);
using FunctionType = void(void *, JobPriority);

private:
FunctionType * __ptrauth_swift_escalation_notification_function Function;
FunctionType *__ptrauth_swift_escalation_notification_function Function;
void *Argument;

public:
EscalationNotificationStatusRecord(FunctionType *fn, void *arg)
: TaskStatusRecord(TaskStatusRecordKind::EscalationNotification),
Function(fn), Argument(arg) {}
: TaskStatusRecord(TaskStatusRecordKind::EscalationNotification),
Function(fn), Argument(arg) {}

void run(JobPriority newPriority) {
Function(Argument, newPriority);
}
void run(JobPriority newPriority) { Function(Argument, newPriority); }

static bool classof(const TaskStatusRecord *record) {
return record->getKind() == TaskStatusRecordKind::EscalationNotification;
Expand Down
7 changes: 6 additions & 1 deletion stdlib/public/Concurrency/Task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1011,7 +1011,12 @@ swift_task_addCancellationHandlerImpl(
auto *record = new (allocation)
CancellationNotificationStatusRecord(unsigned_handler, context);

swift_task_addStatusRecord(record);
if (swift_task_addStatusRecord(record))
return record;

// else, the task was already cancelled, so while the record was added,
// we must run it immediately here since no other task will trigger it.
record->run();
return record;
}

Expand Down
8 changes: 2 additions & 6 deletions stdlib/public/Concurrency/TaskCancellation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,8 @@ public func withTaskCancellationHandler<T>(
) async rethrows -> T {
let task = Builtin.getCurrentAsyncTask()

guard !_taskIsCancelled(task) else {
// If the current task is already cancelled, run the handler immediately.
handler()
return try await operation()
}

// unconditionally add the cancellation record to the task.
// if the task was already cancelled, it will be executed right away.
let record = _taskAddCancellationHandler(handler: handler)
defer { _taskRemoveCancellationHandler(record: record) }

Expand Down
12 changes: 6 additions & 6 deletions test/Concurrency/Runtime/async_task_cancellation_early.swift
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-concurrency %import-libdispatch -parse-as-library)
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-concurrency %import-libdispatch -parse-as-library) | %FileCheck %s --dump-input=always

// REQUIRES: executable_test
// REQUIRES: concurrency
// REQUIRES: libdispatch

// Temporarily disabled to unblock PR testing:
// REQUIRES: rdar80745964

// rdar://76038845
// UNSUPPORTED: use_os_stdlib
// UNSUPPORTED: back_deployment_runtime
Expand All @@ -21,14 +24,11 @@ func test_detach_cancel_child_early() async {

let xx = await childCancelled
print("child, cancelled: \(xx)") // CHECK: child, cancelled: true
let cancelled = Task.isCancelled
print("self, cancelled: \(cancelled )") // CHECK: self, cancelled: true
let cancelled = Task.isCancelled
print("self, cancelled: \(cancelled)") // CHECK: self, cancelled: true
return cancelled
}

// no sleep here -- this confirms that the child task `x`
// carries the cancelled flag, as it is started from a cancelled task.

h.cancel()
print("handle cancel")
let got = try! await h.value
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-concurrency %import-libdispatch -parse-as-library)
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-concurrency %import-libdispatch -parse-as-library) | %FileCheck %s

// REQUIRES: executable_test
// REQUIRES: concurrency
Expand All @@ -10,11 +10,13 @@

import Dispatch

let seconds: UInt64 = 1_000_000_000

@available(SwiftStdlib 5.5, *)
func test_detach_cancel_while_child_running() async {
let h: Task<Bool, Error> = detach {
let task: Task<Bool, Error> = Task.detached {
async let childCancelled: Bool = { () -> Bool in
await Task.sleep(3_000_000_000)
await Task.sleep(3 * seconds)
return Task.isCancelled
}()

Expand All @@ -26,17 +28,74 @@ func test_detach_cancel_while_child_running() async {
}

// sleep here, i.e. give the task a moment to start running
await Task.sleep(2_000_000_000)
await Task.sleep(2 * seconds)

task.cancel()
print("task.cancel()")
let got = try! await task.get()
print("was cancelled: \(got)") // CHECK: was cancelled: true
}

@available(SwiftStdlib 5.5, *)
func test_cancel_while_withTaskCancellationHandler_inflight() async {
let task: Task<Bool, Error> = Task.detached {
await withTaskCancellationHandler {
await Task.sleep(2 * seconds)
print("operation-1")
await Task.sleep(1 * seconds)
print("operation-2")
return Task.isCancelled
} onCancel: {
print("onCancel")
}

}

await Task.sleep(1 * seconds)

// CHECK: task.cancel()
// CHECK: onCancel
// CHECK: operation-1
// CHECK: operation-2
print("task.cancel()")
task.cancel()
let got = try! await task.get()
print("was cancelled: \(got)") // CHECK: was cancelled: true
}

@available(SwiftStdlib 5.5, *)
func test_cancel_while_withTaskCancellationHandler_onlyOnce() async {
let task: Task<Bool, Error> = Task.detached {
await withTaskCancellationHandler {
await Task.sleep(2 * seconds)
await Task.sleep(2 * seconds)
await Task.sleep(2 * seconds)
print("operation-done")
return Task.isCancelled
} onCancel: {
print("onCancel")
}
}

await Task.sleep(1 * seconds)

h.cancel()
print("handle cancel")
let got = try! await h.get()
// CHECK: task.cancel()
// CHECK: onCancel
// onCancel runs only once, even though we attempt to cancel the task many times
// CHECK-NEXT: operation-done
print("task.cancel()")
task.cancel()
task.cancel()
task.cancel()
let got = try! await task.get()
print("was cancelled: \(got)") // CHECK: was cancelled: true
}

@available(SwiftStdlib 5.5, *)
@main struct Main {
static func main() async {
await test_detach_cancel_while_child_running()
await test_cancel_while_withTaskCancellationHandler_inflight()
await test_cancel_while_withTaskCancellationHandler_onlyOnce()
}
}