Skip to content

[Concurrency] Refine getResumeFunctionForLogging to avoid reading invalid future contexts. #73796

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion include/swift/ABI/Task.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,11 @@ class AsyncTask : public Job {
/// failing that will return ResumeTask. The returned function pointer may
/// have a different signature than ResumeTask, and it's only for identifying
/// code associated with the task.
const void *getResumeFunctionForLogging();
///
/// If isStarting is true, look into the resume context when appropriate
/// to pull out a wrapped resume function. If isStarting is false, assume the
/// resume context may not be valid and just return the wrapper.
const void *getResumeFunctionForLogging(bool isStarting);

/// Given that we've already fully established the job context
/// in the current thread, start running this task. To establish
Expand Down
17 changes: 11 additions & 6 deletions stdlib/public/Concurrency/Task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ const void
*const swift::_swift_concurrency_debug_task_future_wait_resume_adapter =
reinterpret_cast<void *>(task_future_wait_resume_adapter);

const void *AsyncTask::getResumeFunctionForLogging() {
const void *AsyncTask::getResumeFunctionForLogging(bool isStarting) {
const void *result = reinterpret_cast<const void *>(ResumeTask);

if (ResumeTask == non_future_adapter) {
Expand All @@ -553,11 +553,16 @@ const void *AsyncTask::getResumeFunctionForLogging() {
sizeof(FutureAsyncContextPrefix));
result =
reinterpret_cast<const void *>(asyncContextPrefix->asyncEntryPoint);
} else if (ResumeTask == task_wait_throwing_resume_adapter) {
auto context = static_cast<TaskFutureWaitAsyncContext *>(ResumeContext);
result = reinterpret_cast<const void *>(context->ResumeParent);
} else if (ResumeTask == task_future_wait_resume_adapter) {
result = reinterpret_cast<const void *>(ResumeContext->ResumeParent);
}

// Future contexts may not be valid if the task was already running before.
if (isStarting) {
if (ResumeTask == task_wait_throwing_resume_adapter) {
auto context = static_cast<TaskFutureWaitAsyncContext *>(ResumeContext);
result = reinterpret_cast<const void *>(context->ResumeParent);
} else if (ResumeTask == task_future_wait_resume_adapter) {
result = reinterpret_cast<const void *>(ResumeContext->ResumeParent);
}
}

return __ptrauth_swift_runtime_function_entry_strip(result);
Expand Down
6 changes: 3 additions & 3 deletions stdlib/public/Concurrency/TaskPrivate.h
Original file line number Diff line number Diff line change
Expand Up @@ -709,10 +709,10 @@ class alignas(2 * sizeof(void*)) ActiveTaskStatus {
return record_iterator::rangeBeginning(getInnermostRecord());
}

void traceStatusChanged(AsyncTask *task) {
void traceStatusChanged(AsyncTask *task, bool isStarting) {
concurrency::trace::task_status_changed(
task, static_cast<uint8_t>(getStoredPriority()), isCancelled(),
isStoredPriorityEscalated(), isRunning(), isEnqueued());
isStoredPriorityEscalated(), isStarting, isRunning(), isEnqueued());
}
};

Expand Down Expand Up @@ -938,7 +938,7 @@ inline void AsyncTask::flagAsRunning() {
if (_private()._status().compare_exchange_weak(oldStatus, newStatus,
/* success */ std::memory_order_relaxed,
/* failure */ std::memory_order_relaxed)) {
newStatus.traceStatusChanged(this);
newStatus.traceStatusChanged(this, true);
adoptTaskVoucher(this);
swift_task_enterThreadLocalContext(
(char *)&_private().ExclusivityAccessSet[0]);
Expand Down
14 changes: 7 additions & 7 deletions stdlib/public/Concurrency/TaskStatus.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ static bool withStatusRecordLock(AsyncTask *task, ActiveTaskStatus status,

status = newStatus;

status.traceStatusChanged(task);
status.traceStatusChanged(task, false);
worker.flagQueueIsPublished(lockingRecord);
installedLockRecord = true;

Expand Down Expand Up @@ -268,7 +268,7 @@ static bool withStatusRecordLock(AsyncTask *task, ActiveTaskStatus status,
if (task->_private()._status().compare_exchange_weak(status, newStatus,
/*success*/ std::memory_order_release,
/*failure*/ std::memory_order_relaxed)) {
newStatus.traceStatusChanged(task);
newStatus.traceStatusChanged(task, false);
break;
}
}
Expand Down Expand Up @@ -322,7 +322,7 @@ bool swift::addStatusRecord(AsyncTask *task, TaskStatusRecord *newRecord,
if (task->_private()._status().compare_exchange_weak(oldStatus, newStatus,
/*success*/ std::memory_order_release,
/*failure*/ std::memory_order_relaxed)) {
newStatus.traceStatusChanged(task);
newStatus.traceStatusChanged(task, false);
return true;
} else {
// Retry
Expand Down Expand Up @@ -404,7 +404,7 @@ void swift::removeStatusRecord(AsyncTask *task, TaskStatusRecord *record,
if (task->_private()._status().compare_exchange_weak(oldStatus, newStatus,
/*success*/ std::memory_order_relaxed,
/*failure*/ std::memory_order_relaxed)) {
newStatus.traceStatusChanged(task);
newStatus.traceStatusChanged(task, false);
return;
}
}
Expand Down Expand Up @@ -436,7 +436,7 @@ void swift::removeStatusRecord(AsyncTask *task, TaskStatusRecord *record,
if (task->_private()._status().compare_exchange_weak(oldStatus, newStatus,
/*success*/ std::memory_order_relaxed,
/*failure*/ std::memory_order_relaxed)) {
newStatus.traceStatusChanged(task);
newStatus.traceStatusChanged(task, false);
return;
}
// Restart the loop again - someone else modified status concurrently
Expand Down Expand Up @@ -494,7 +494,7 @@ void swift::removeStatusRecordWhere(
if (task->_private()._status().compare_exchange_weak(oldStatus, newStatus,
/*success*/ std::memory_order_relaxed,
/*failure*/ std::memory_order_relaxed)) {
newStatus.traceStatusChanged(task);
newStatus.traceStatusChanged(task, false);
return;
}
}
Expand Down Expand Up @@ -904,7 +904,7 @@ static void swift_task_cancelImpl(AsyncTask *task) {
}
}

newStatus.traceStatusChanged(task);
newStatus.traceStatusChanged(task, false);
if (newStatus.getInnermostRecord() == NULL) {
// No records, nothing to propagate
return;
Expand Down
2 changes: 1 addition & 1 deletion stdlib/public/Concurrency/Tracing.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void task_create(AsyncTask *task, AsyncTask *parent, TaskGroup *group,
void task_destroy(AsyncTask *task);

void task_status_changed(AsyncTask *task, uint8_t maxPriority, bool isCancelled,
bool isEscalated, bool isRunning, bool isEnqueued);
bool isEscalated, bool isStarting, bool isRunning, bool isEnqueued);

void task_flags_changed(AsyncTask *task, uint8_t jobPriority, bool isChildTask,
bool isFuture, bool isGroupChildTask,
Expand Down
6 changes: 3 additions & 3 deletions stdlib/public/Concurrency/TracingSignpost.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ inline void task_create(AsyncTask *task, AsyncTask *parent, TaskGroup *group,
" resumefn=%p jobPriority=%u isChildTask=%{bool}d, isFuture=%{bool}d "
"isGroupChildTask=%{bool}d isAsyncLetTask=%{bool}d parent=%" PRIx64
" group=%p asyncLet=%p",
task->getTaskId(), task->getResumeFunctionForLogging(), jobPriority,
task->getTaskId(), task->getResumeFunctionForLogging(true), jobPriority,
isChildTask, isFuture, isGroupChildTask, isAsyncLetTask, parentID, group,
asyncLet);
}
Expand All @@ -206,15 +206,15 @@ inline void task_destroy(AsyncTask *task) {

inline void task_status_changed(AsyncTask *task, uint8_t maxPriority,
bool isCancelled, bool isEscalated,
bool isRunning, bool isEnqueued) {
bool isStarting, bool isRunning, bool isEnqueued) {
ENSURE_LOGS();
auto id = os_signpost_id_make_with_pointer(TaskLog, task);
os_signpost_event_emit(
TaskLog, id, SWIFT_LOG_TASK_STATUS_CHANGED_NAME,
"task=%" PRIx64 " resumefn=%p "
"maxPriority=%u, isCancelled=%{bool}d "
"isEscalated=%{bool}d, isRunning=%{bool}d, isEnqueued=%{bool}d",
task->getTaskId(), task->getResumeFunctionForLogging(), maxPriority,
task->getTaskId(), task->getResumeFunctionForLogging(isStarting), maxPriority,
isCancelled, isEscalated, isRunning, isEnqueued);
}

Expand Down
2 changes: 1 addition & 1 deletion stdlib/public/Concurrency/TracingStubs.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ inline void task_resume(AsyncTask *task) {}

inline void task_status_changed(AsyncTask *task, uint8_t maxPriority,
bool isCancelled, bool isEscalated,
bool isRunning, bool isEnqueued) {}
bool isStarting, bool isRunning, bool isEnqueued) {}

inline void task_flags_changed(AsyncTask *task, uint8_t jobPriority,
bool isChildTask, bool isFuture,
Expand Down