Skip to content

Commit b34b67c

Browse files
committed
mark waiting status only when we actually MustWait, while holding group lock
1 parent 13de654 commit b34b67c

File tree

1 file changed

+33
-24
lines changed

1 file changed

+33
-24
lines changed

stdlib/public/Concurrency/TaskGroup.cpp

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ class TaskGroupBase : public TaskGroupTaskStatusRecord {
378378
// ==== Status manipulation -------------------------------------------------
379379

380380
TaskGroupStatus statusLoadRelaxed() const;
381+
TaskGroupStatus statusLoadAcquire() const;
381382

382383
std::string statusString() const;
383384

@@ -409,6 +410,10 @@ class TaskGroupBase : public TaskGroupTaskStatusRecord {
409410
/// Remove waiting status bit.
410411
TaskGroupStatus statusRemoveWaitingRelease();
411412

413+
/// Mark the waiting status bit.
414+
/// A waiting task MUST have been already enqueued in the `waitQueue`.
415+
TaskGroupStatus statusMarkWaitingAssumeRelease();
416+
412417
/// Cancels the group and returns true if was already cancelled before.
413418
/// After this function returns, the group is guaranteed to be cancelled.
414419
///
@@ -561,6 +566,10 @@ TaskGroupStatus TaskGroupBase::statusLoadRelaxed() const {
561566
return TaskGroupStatus{status.load(std::memory_order_relaxed)};
562567
}
563568

569+
TaskGroupStatus TaskGroupBase::statusLoadAcquire() const {
570+
return TaskGroupStatus{status.load(std::memory_order_acquire)};
571+
}
572+
564573
std::string TaskGroupBase::statusString() const {
565574
return statusLoadRelaxed().to_string(this);
566575
}
@@ -580,6 +589,12 @@ TaskGroupStatus TaskGroupBase::statusMarkWaitingAssumeAcquire() {
580589
return TaskGroupStatus{old | TaskGroupStatus::waiting};
581590
}
582591

592+
TaskGroupStatus TaskGroupBase::statusMarkWaitingAssumeRelease() {
593+
auto old = status.fetch_or(TaskGroupStatus::waiting,
594+
std::memory_order_release);
595+
return TaskGroupStatus{old | TaskGroupStatus::waiting};
596+
}
597+
583598
TaskGroupStatus TaskGroupBase::statusRemoveWaitingRelease() {
584599
auto old = status.fetch_and(~TaskGroupStatus::waiting,
585600
std::memory_order_release);
@@ -702,18 +717,6 @@ class DiscardingTaskGroup: public TaskGroupBase {
702717
return true;
703718
}
704719

705-
/// Returns *assumed* new status, including the just performed +1.
706-
TaskGroupStatus statusMarkWaitingAssumeAcquire() {
707-
auto old = status.fetch_or(TaskGroupStatus::waiting, std::memory_order_acquire);
708-
return TaskGroupStatus{old | TaskGroupStatus::waiting};
709-
}
710-
711-
TaskGroupStatus statusRemoveWaitingRelease() {
712-
auto old = status.fetch_and(~TaskGroupStatus::waiting,
713-
std::memory_order_release);
714-
return TaskGroupStatus{old};
715-
}
716-
717720
/// Returns *assumed* new status.
718721
TaskGroupStatus statusAddReadyAssumeAcquire(const DiscardingTaskGroup *group) {
719722
assert(group->isDiscardingResults());
@@ -1203,8 +1206,8 @@ void DiscardingTaskGroup::offer(AsyncTask *completedTask, AsyncContext *context)
12031206
// Immediately decrement the pending count.
12041207
// We can do this, since in this mode there is no ready count to keep track of,
12051208
// and we immediately discard the result.
1206-
SWIFT_TASK_GROUP_DEBUG_LOG(this, "discard result, hadError:%d, was pending:%llu",
1207-
hadErrorResult, assumed.pendingTasks(this));
1209+
SWIFT_TASK_GROUP_DEBUG_LOG(this, "discard result, hadError:%d, was pending:%llu, status = %s",
1210+
hadErrorResult, assumed.pendingTasks(this), assumed.to_string(this).c_str());
12081211
// If this was the last pending task, and there is a waiting task (from waitAll),
12091212
// we must resume the task; but not otherwise. There cannot be any waiters on next()
12101213
// while we're discarding results.
@@ -1662,9 +1665,9 @@ static void swift_taskGroup_waitAllImpl(
16621665
waitingTask, bodyError, group->statusString().c_str(), to_string(polled.status).c_str());
16631666

16641667
switch (polled.status) {
1665-
case PollStatus::MustWait:
1666-
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAllImpl MustWait, pending tasks exist, waiting task = %p",
1667-
waitingTask);
1668+
case PollStatus::MustWait: {
1669+
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAllImpl MustWait, pending tasks exist, waiting task = %p",
1670+
waitingTask);
16681671
// The waiting task has been queued on the channel,
16691672
// there were pending tasks so it will be woken up eventually.
16701673
#ifdef __ARM_ARCH_7K__
@@ -1673,8 +1676,9 @@ static void swift_taskGroup_waitAllImpl(
16731676
#else /* __ARM_ARCH_7K__ */
16741677
return;
16751678
#endif /* __ARM_ARCH_7K__ */
1679+
}
16761680

1677-
case PollStatus::Error:
1681+
case PollStatus::Error: {
16781682
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAllImpl Error, waiting task = %p, body error = %p, status:%s",
16791683
waitingTask, bodyError, group->statusString().c_str());
16801684
#if SWIFT_TASK_GROUP_BODY_THROWN_ERROR_WINS
@@ -1695,9 +1699,10 @@ static void swift_taskGroup_waitAllImpl(
16951699
}
16961700

16971701
return waitingTask->runInFullyEstablishedContext();
1702+
}
16981703

16991704
case PollStatus::Empty:
1700-
case PollStatus::Success:
1705+
case PollStatus::Success: {
17011706
/// Anything else than a "MustWait" can be treated as a successful poll.
17021707
/// Only if there are in flight pending tasks do we need to wait after all.
17031708
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAllImpl %s, waiting task = %p, status:%s",
@@ -1712,12 +1717,11 @@ static void swift_taskGroup_waitAllImpl(
17121717
}
17131718

17141719
return waitingTask->runInFullyEstablishedContext();
1720+
}
17151721
}
17161722
}
17171723

1718-
/// Must be called while holding the `taskGroup.lock`!
1719-
/// This is because the discarding task group still has some follow-up operations that must
1720-
/// be performed atomically after this operation sometimes, so we cannot unlock inside `waitAll` itself.
1724+
/// Caller must mark the `waiting` status bit when MustWait is returned.
17211725
PollResult TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask) {
17221726
lock(); // TODO: remove group lock, and use status for synchronization
17231727

@@ -1732,7 +1736,12 @@ PollResult TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask)
17321736
bool haveRunOneChildTaskInline = false;
17331737

17341738
reevaluate_if_TaskGroup_has_results:;
1735-
auto assumed = statusMarkWaitingAssumeAcquire();
1739+
// Paired with a release when marking Waiting,
1740+
// otherwise we don't modify the status
1741+
auto assumed = statusLoadAcquire();
1742+
1743+
SWIFT_TASK_GROUP_DEBUG_LOG(this, "waitAll, LOAD STATUS, status = %s",
1744+
assumed.to_string(this).c_str());
17361745

17371746
// ==== 1) may be able to bail out early if no tasks are pending -------------
17381747
if (assumed.isEmpty(this)) {
@@ -1759,7 +1768,6 @@ PollResult TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask)
17591768
// No tasks in flight, we know no tasks were submitted before this poll
17601769
// was issued, and if we parked here we'd potentially never be woken up.
17611770
// Bail out and return `nil` from `group.next()`.
1762-
statusRemoveWaitingRelease();
17631771
unlock();
17641772
return result;
17651773
}
@@ -1787,6 +1795,7 @@ PollResult TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask)
17871795
waitHead, waitingTask,
17881796
/*success*/ std::memory_order_release,
17891797
/*failure*/ std::memory_order_acquire)) {
1798+
statusMarkWaitingAssumeRelease();
17901799
unlock(); // TODO: remove fragment lock, and use status for synchronization
17911800
#if SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL
17921801
// The logic here is paired with the logic in TaskGroupBase::offer. Once

0 commit comments

Comments
 (0)