Skip to content

Commit c87ebdc

Browse files
committed
Must not modify waitingTask context outside lock
1 parent 13de654 commit c87ebdc

File tree

1 file changed

+50
-39
lines changed

1 file changed

+50
-39
lines changed

stdlib/public/Concurrency/TaskGroup.cpp

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -366,8 +366,9 @@ class TaskGroupBase : public TaskGroupTaskStatusRecord {
366366
///
367367
/// \param bodyError error thrown by the body of a with...TaskGroup method
368368
/// \param waitingTask the task waiting on the group
369+
/// \param rawContext used to resume the waiting task
369370
/// \return how the waiting task should be handled, e.g. must wait or can be completed immediately
370-
PollResult waitAll(SwiftError* bodyError, AsyncTask *waitingTask);
371+
PollResult waitAll(SwiftError* bodyError, AsyncTask *waitingTask, AsyncContext* rawContext);
371372

372373
// Enqueue the completed task onto ready queue if there are no waiting tasks yet
373374
virtual void enqueueCompletedTask(AsyncTask *completedTask, bool hadErrorResult) = 0;
@@ -378,6 +379,7 @@ class TaskGroupBase : public TaskGroupTaskStatusRecord {
378379
// ==== Status manipulation -------------------------------------------------
379380

380381
TaskGroupStatus statusLoadRelaxed() const;
382+
TaskGroupStatus statusLoadAcquire() const;
381383

382384
std::string statusString() const;
383385

@@ -409,6 +411,10 @@ class TaskGroupBase : public TaskGroupTaskStatusRecord {
409411
/// Remove waiting status bit.
410412
TaskGroupStatus statusRemoveWaitingRelease();
411413

414+
/// Mark the waiting status bit.
415+
/// A waiting task MUST have been already enqueued in the `waitQueue`.
416+
TaskGroupStatus statusMarkWaitingAssumeRelease();
417+
412418
/// Cancels the group and returns true if was already cancelled before.
413419
/// After this function returns, the group is guaranteed to be cancelled.
414420
///
@@ -521,7 +527,7 @@ struct TaskGroupStatus {
521527
/// TaskGroupStatus{ C:{cancelled} W:{waiting task} R:{ready tasks} P:{pending tasks} {binary repr} }
522528
/// If discarding results:
523529
/// TaskGroupStatus{ C:{cancelled} W:{waiting task} P:{pending tasks} {binary repr} }
524-
std::string to_string(const TaskGroupBase* _Nonnull group) {
530+
std::string to_string(const TaskGroupBase* group) {
525531
std::string str;
526532
str.append("TaskGroupStatus{ ");
527533
str.append("C:"); // cancelled
@@ -548,7 +554,7 @@ struct TaskGroupStatus {
548554
bool TaskGroupBase::statusCompletePendingReadyWaiting(TaskGroupStatus &old) {
549555
return status.compare_exchange_strong(
550556
old.status, old.completingPendingReadyWaiting(this).status,
551-
/*success*/ std::memory_order_relaxed,
557+
/*success*/ std::memory_order_release,
552558
/*failure*/ std::memory_order_relaxed);
553559
}
554560

@@ -561,6 +567,10 @@ TaskGroupStatus TaskGroupBase::statusLoadRelaxed() const {
561567
return TaskGroupStatus{status.load(std::memory_order_relaxed)};
562568
}
563569

570+
TaskGroupStatus TaskGroupBase::statusLoadAcquire() const {
571+
return TaskGroupStatus{status.load(std::memory_order_acquire)};
572+
}
573+
564574
std::string TaskGroupBase::statusString() const {
565575
return statusLoadRelaxed().to_string(this);
566576
}
@@ -580,6 +590,12 @@ TaskGroupStatus TaskGroupBase::statusMarkWaitingAssumeAcquire() {
580590
return TaskGroupStatus{old | TaskGroupStatus::waiting};
581591
}
582592

593+
TaskGroupStatus TaskGroupBase::statusMarkWaitingAssumeRelease() {
594+
auto old = status.fetch_or(TaskGroupStatus::waiting,
595+
std::memory_order_release);
596+
return TaskGroupStatus{old | TaskGroupStatus::waiting};
597+
}
598+
583599
TaskGroupStatus TaskGroupBase::statusRemoveWaitingRelease() {
584600
auto old = status.fetch_and(~TaskGroupStatus::waiting,
585601
std::memory_order_release);
@@ -702,18 +718,6 @@ class DiscardingTaskGroup: public TaskGroupBase {
702718
return true;
703719
}
704720

705-
/// Returns *assumed* new status, including the just performed +1.
706-
TaskGroupStatus statusMarkWaitingAssumeAcquire() {
707-
auto old = status.fetch_or(TaskGroupStatus::waiting, std::memory_order_acquire);
708-
return TaskGroupStatus{old | TaskGroupStatus::waiting};
709-
}
710-
711-
TaskGroupStatus statusRemoveWaitingRelease() {
712-
auto old = status.fetch_and(~TaskGroupStatus::waiting,
713-
std::memory_order_release);
714-
return TaskGroupStatus{old};
715-
}
716-
717721
/// Returns *assumed* new status.
718722
TaskGroupStatus statusAddReadyAssumeAcquire(const DiscardingTaskGroup *group) {
719723
assert(group->isDiscardingResults());
@@ -1145,7 +1149,7 @@ void AccumulatingTaskGroup::offer(AsyncTask *completedTask, AsyncContext *contex
11451149
hadErrorResult = true;
11461150
}
11471151

1148-
SWIFT_TASK_GROUP_DEBUG_LOG(this, "ready: %d, pending: %u",
1152+
SWIFT_TASK_GROUP_DEBUG_LOG(this, "ready: %d, pending: %llu",
11491153
assumed.readyTasks(this), assumed.pendingTasks(this));
11501154

11511155
// ==== a) has waiting task, so let us complete it right away
@@ -1198,13 +1202,14 @@ void DiscardingTaskGroup::offer(AsyncTask *completedTask, AsyncContext *context)
11981202

11991203
/// If we're the last task we've been waiting for, and there is a waiting task on the group
12001204
bool lastPendingTaskAndWaitingTask =
1201-
assumed.pendingTasks(this) == 1 && assumed.hasWaitingTask();
1205+
assumed.pendingTasks(this) == 1 &&
1206+
assumed.hasWaitingTask();
12021207

12031208
// Immediately decrement the pending count.
12041209
// We can do this, since in this mode there is no ready count to keep track of,
12051210
// and we immediately discard the result.
1206-
SWIFT_TASK_GROUP_DEBUG_LOG(this, "discard result, hadError:%d, was pending:%llu",
1207-
hadErrorResult, assumed.pendingTasks(this));
1211+
SWIFT_TASK_GROUP_DEBUG_LOG(this, "discard result, hadError:%d, was pending:%llu, status = %s",
1212+
hadErrorResult, assumed.pendingTasks(this), assumed.to_string(this).c_str());
12081213
// If this was the last pending task, and there is a waiting task (from waitAll),
12091214
// we must resume the task; but not otherwise. There cannot be any waiters on next()
12101215
// while we're discarding results.
@@ -1294,6 +1299,8 @@ void TaskGroupBase::resumeWaitingTask(
12941299
if (statusCompletePendingReadyWaiting(assumed)) {
12951300
// Run the task.
12961301
auto result = PollResult::get(completedTask, hadErrorResult);
1302+
SWIFT_TASK_GROUP_DEBUG_LOG(this, "resume waiting DONE, task = %p, complete with = %p, status = %s",
1303+
waitingTask, completedTask, statusString().c_str());
12971304

12981305
// Remove the child from the task group's running tasks list.
12991306
// The parent task isn't currently running (we're about to wake
@@ -1645,11 +1652,9 @@ static void swift_taskGroup_waitAllImpl(
16451652
ThrowingTaskFutureWaitContinuationFunction *resumeFunction,
16461653
AsyncContext *rawContext) {
16471654
auto waitingTask = swift_task_getCurrent();
1648-
waitingTask->ResumeTask = task_group_wait_resume_adapter;
1649-
waitingTask->ResumeContext = rawContext;
16501655

16511656
auto group = asBaseImpl(_group);
1652-
PollResult polled = group->waitAll(bodyError, waitingTask);
1657+
PollResult polled = group->waitAll(bodyError, waitingTask, rawContext);
16531658

16541659
auto context = static_cast<TaskFutureWaitAsyncContext *>(rawContext);
16551660
context->ResumeParent =
@@ -1662,19 +1667,17 @@ static void swift_taskGroup_waitAllImpl(
16621667
waitingTask, bodyError, group->statusString().c_str(), to_string(polled.status).c_str());
16631668

16641669
switch (polled.status) {
1665-
case PollStatus::MustWait:
1666-
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAllImpl MustWait, pending tasks exist, waiting task = %p",
1667-
waitingTask);
1670+
case PollStatus::MustWait: {
16681671
// The waiting task has been queued on the channel,
16691672
// there were pending tasks so it will be woken up eventually.
16701673
#ifdef __ARM_ARCH_7K__
1671-
return workaround_function_swift_taskGroup_waitAllImpl(
1674+
workaround_function_swift_taskGroup_waitAllImpl(
16721675
resultPointer, callerContext, _group, bodyError, resumeFunction, rawContext);
1673-
#else /* __ARM_ARCH_7K__ */
1674-
return;
16751676
#endif /* __ARM_ARCH_7K__ */
1677+
return;
1678+
}
16761679

1677-
case PollStatus::Error:
1680+
case PollStatus::Error: {
16781681
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAllImpl Error, waiting task = %p, body error = %p, status:%s",
16791682
waitingTask, bodyError, group->statusString().c_str());
16801683
#if SWIFT_TASK_GROUP_BODY_THROWN_ERROR_WINS
@@ -1695,9 +1698,10 @@ static void swift_taskGroup_waitAllImpl(
16951698
}
16961699

16971700
return waitingTask->runInFullyEstablishedContext();
1701+
}
16981702

16991703
case PollStatus::Empty:
1700-
case PollStatus::Success:
1704+
case PollStatus::Success: {
17011705
/// Anything else than a "MustWait" can be treated as a successful poll.
17021706
/// Only if there are in flight pending tasks do we need to wait after all.
17031707
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAllImpl %s, waiting task = %p, status:%s",
@@ -1712,14 +1716,17 @@ static void swift_taskGroup_waitAllImpl(
17121716
}
17131717

17141718
return waitingTask->runInFullyEstablishedContext();
1719+
}
17151720
}
17161721
}
17171722

1718-
/// Must be called while holding the `taskGroup.lock`!
1719-
/// This is because the discarding task group still has some follow-up operations that must
1720-
/// be performed atomically after this operation sometimes, so we cannot unlock inside `waitAll` itself.
1721-
PollResult TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask) {
1722-
lock(); // TODO: remove group lock, and use status for synchronization
1723+
PollResult TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask, AsyncContext *rawContext) {
1724+
lock();
1725+
1726+
// must mutate the waiting task while holding the group lock,
1727+
// so we don't get an offer concurrently trying to do so
1728+
waitingTask->ResumeTask = task_group_wait_resume_adapter;
1729+
waitingTask->ResumeContext = rawContext;
17231730

17241731
SWIFT_TASK_GROUP_DEBUG_LOG(this, "waitAll, bodyError = %p, status = %s", bodyError, statusString().c_str());
17251732
PollResult result = PollResult::getEmpty(this->successType);
@@ -1732,7 +1739,11 @@ PollResult TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask)
17321739
bool haveRunOneChildTaskInline = false;
17331740

17341741
reevaluate_if_TaskGroup_has_results:;
1735-
auto assumed = statusMarkWaitingAssumeAcquire();
1742+
// Paired with a release when marking Waiting,
1743+
// otherwise we don't modify the status
1744+
auto assumed = statusLoadAcquire();
1745+
1746+
SWIFT_TASK_GROUP_DEBUG_LOG(this, "waitAll, status = %s", assumed.to_string(this).c_str());
17361747

17371748
// ==== 1) may be able to bail out early if no tasks are pending -------------
17381749
if (assumed.isEmpty(this)) {
@@ -1750,7 +1761,6 @@ PollResult TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask)
17501761
result.status = PollStatus::Error;
17511762
}
17521763
} // else, we're definitely Empty
1753-
17541764
unlock();
17551765
return result;
17561766
}
@@ -1759,7 +1769,6 @@ PollResult TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask)
17591769
// No tasks in flight, we know no tasks were submitted before this poll
17601770
// was issued, and if we parked here we'd potentially never be woken up.
17611771
// Bail out and return `nil` from `group.next()`.
1762-
statusRemoveWaitingRelease();
17631772
unlock();
17641773
return result;
17651774
}
@@ -1787,7 +1796,9 @@ PollResult TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask)
17871796
waitHead, waitingTask,
17881797
/*success*/ std::memory_order_release,
17891798
/*failure*/ std::memory_order_acquire)) {
1790-
unlock(); // TODO: remove fragment lock, and use status for synchronization
1799+
statusMarkWaitingAssumeRelease();
1800+
unlock();
1801+
17911802
#if SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL
17921803
// The logic here is paired with the logic in TaskGroupBase::offer. Once
17931804
// we run the

0 commit comments

Comments
 (0)