Skip to content

Commit 95628b1

Browse files
committed
Offering body error must be done while holding lock
1 parent b929dac commit 95628b1

File tree

1 file changed

+50
-30
lines changed

1 file changed

+50
-30
lines changed

stdlib/public/Concurrency/TaskGroup.cpp

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,24 @@
5151
using namespace swift;
5252

5353
#if 0
54-
#define SWIFT_TASK_GROUP_DEBUG_LOG(group, fmt, ...) \
54+
#define SWIFT_TASK_GROUP_DEBUG_LOG(group, fmt, ...) \
5555
fprintf(stderr, "[%#lx] [%s:%d][group(%p%s)] (%s) " fmt "\n", \
5656
(unsigned long)Thread::current().platformThreadId(), \
5757
__FILE__, __LINE__, \
5858
group, group->isDiscardingResults() ? ",discardResults" : "", \
5959
__FUNCTION__, \
6060
__VA_ARGS__)
61+
62+
#define SWIFT_TASK_GROUP_DEBUG_LOG_0(group, fmt, ...) \
63+
fprintf(stderr, "[%#lx] [%s:%d][group(%p)] (%s) " fmt "\n", \
64+
(unsigned long)Thread::current().platformThreadId(), \
65+
__FILE__, __LINE__, \
66+
group, \
67+
__FUNCTION__, \
68+
__VA_ARGS__)
6169
#else
6270
#define SWIFT_TASK_GROUP_DEBUG_LOG(group, fmt, ...) (void)0
71+
#define SWIFT_TASK_GROUP_DEBUG_LOG_0(group, fmt, ...) (void)0
6372
#endif
6473

6574
using FutureFragment = AsyncTask::FutureFragment;
@@ -354,7 +363,11 @@ class TaskGroupBase : public TaskGroupTaskStatusRecord {
354363
/// There can be only at-most-one waiting task on a group at any given time,
355364
/// and the waiting task is expected to be the parent task in which the group
356365
/// body is running.
357-
PollResult waitAll(AsyncTask *waitingTask);
366+
///
367+
/// \param bodyError error thrown by the body of a with...TaskGroup method
368+
/// \param waitingTask the task waiting on the group
369+
/// \return how the waiting task should be handled, e.g. must wait or can be completed immediately
370+
PollResult waitAll(SwiftError* bodyError, AsyncTask *waitingTask);
358371

359372
// Enqueue the completed task onto ready queue if there are no waiting tasks yet
360373
virtual void enqueueCompletedTask(AsyncTask *completedTask, bool hadErrorResult) = 0;
@@ -411,6 +424,15 @@ class TaskGroupBase : public TaskGroupTaskStatusRecord {
411424
virtual TaskGroupStatus statusAddPendingTaskRelaxed(bool unconditionally) = 0;
412425
};
413426

427+
static std::string to_string(TaskGroupBase::PollStatus status) {
428+
switch (status) {
429+
case TaskGroupBase::PollStatus::Empty: return "Empty";
430+
case TaskGroupBase::PollStatus::MustWait: return "MustWait";
431+
case TaskGroupBase::PollStatus::Success: return "Success";
432+
case TaskGroupBase::PollStatus::Error: return "Error";
433+
}
434+
}
435+
414436
/// The status of a task group.
415437
///
416438
/// Its exact structure depends on the type of group, and therefore a group must be passed to operations
@@ -855,8 +877,8 @@ static void swift_taskGroup_initializeWithFlagsImpl(size_t rawGroupFlags,
855877
TaskGroup *group, const Metadata *T) {
856878

857879
TaskGroupFlags groupFlags(rawGroupFlags);
858-
SWIFT_TASK_DEBUG_LOG("group(%p) create; flags: isDiscardingResults=%d",
859-
group, groupFlags.isDiscardResults());
880+
SWIFT_TASK_GROUP_DEBUG_LOG_0(group, "create group; flags: isDiscardingResults=%d",
881+
groupFlags.isDiscardResults());
860882

861883
TaskGroupBase *impl;
862884
if (groupFlags.isDiscardResults()) {
@@ -1618,7 +1640,7 @@ static void swift_taskGroup_waitAllImpl(
16181640
waitingTask->ResumeContext = rawContext;
16191641

16201642
auto group = asBaseImpl(_group);
1621-
PollResult polled = group->waitAll(waitingTask);
1643+
PollResult polled = group->waitAll(bodyError, waitingTask);
16221644

16231645
auto context = static_cast<TaskFutureWaitAsyncContext *>(rawContext);
16241646
context->ResumeParent =
@@ -1627,23 +1649,13 @@ static void swift_taskGroup_waitAllImpl(
16271649
context->errorResult = nullptr;
16281650
context->successResultPointer = resultPointer;
16291651

1630-
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAllImpl, waiting task = %p, bodyError = %p, status:%s",
1631-
waitingTask, bodyError, group->statusString().c_str());
1652+
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAllImpl, waiting task = %p, bodyError = %p, status:%s, polled.status = %s",
1653+
waitingTask, bodyError, group->statusString().c_str(), to_string(polled.status).c_str());
16321654

16331655
switch (polled.status) {
16341656
case PollStatus::MustWait:
1635-
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAll MustWait, pending tasks exist, waiting task = %p",
1657+
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAllImpl MustWait, pending tasks exist, waiting task = %p",
16361658
waitingTask);
1637-
if (bodyError && group->isDiscardingResults()) {
1638-
auto discardingGroup = asDiscardingImpl(_group);
1639-
bool storedBodyError = discardingGroup->offerBodyError(bodyError);
1640-
if (storedBodyError) {
1641-
SWIFT_TASK_GROUP_DEBUG_LOG(
1642-
group, "waitAll, stored error thrown by with...Group body, error = %p",
1643-
bodyError);
1644-
}
1645-
}
1646-
16471659
// The waiting task has been queued on the channel,
16481660
// there were pending tasks so it will be woken up eventually.
16491661
#ifdef __ARM_ARCH_7K__
@@ -1654,7 +1666,7 @@ static void swift_taskGroup_waitAllImpl(
16541666
#endif /* __ARM_ARCH_7K__ */
16551667

16561668
case PollStatus::Error:
1657-
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAll found error, waiting task = %p, body error = %p, status:%s",
1669+
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAllImpl Error, waiting task = %p, body error = %p, status:%s",
16581670
waitingTask, bodyError, group->statusString().c_str());
16591671
#if SWIFT_TASK_GROUP_BODY_THROWN_ERROR_WINS
16601672
if (bodyError) {
@@ -1676,17 +1688,10 @@ static void swift_taskGroup_waitAllImpl(
16761688
return waitingTask->runInFullyEstablishedContext();
16771689

16781690
case PollStatus::Empty:
1679-
/// Anything else than a "MustWait" can be treated as a successful poll.
1680-
/// Only if there are in flight pending tasks do we need to wait after all.
1681-
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAll %s, waiting task = %p, status:%s",
1682-
polled.status == TaskGroupBase::PollStatus::Empty ? "empty" : "success",
1683-
waitingTask, group->statusString().c_str());
1684-
1685-
16861691
case PollStatus::Success:
16871692
/// Anything else than a "MustWait" can be treated as a successful poll.
16881693
/// Only if there are in flight pending tasks do we need to wait after all.
1689-
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAll %s, waiting task = %p, status:%s",
1694+
SWIFT_TASK_GROUP_DEBUG_LOG(group, "waitAllImpl %s, waiting task = %p, status:%s",
16901695
polled.status == TaskGroupBase::PollStatus::Empty ? "empty" : "success",
16911696
waitingTask, group->statusString().c_str());
16921697

@@ -1717,10 +1722,13 @@ bool DiscardingTaskGroup::offerBodyError(SwiftError* _Nonnull bodyError) {
17171722
return true;
17181723
}
17191724

1720-
PollResult TaskGroupBase::waitAll(AsyncTask *waitingTask) {
1725+
/// Must be called while holding the `taskGroup.lock`!
1726+
/// This is because the discarding task group still has some follow-up operations that must
1727+
/// be performed atomically after this operation sometimes, so we cannot unlock inside `waitAll` itself.
1728+
PollResult TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask) {
17211729
lock(); // TODO: remove group lock, and use status for synchronization
1722-
1723-
SWIFT_TASK_GROUP_DEBUG_LOG(this, "waitAll, status = %s", statusString().c_str());
1730+
1731+
SWIFT_TASK_GROUP_DEBUG_LOG(this, "waitAll, bodyError = %p, status = %s", bodyError, statusString().c_str());
17241732
PollResult result = PollResult::getEmpty(this->successType);
17251733
result.status = PollStatus::Empty;
17261734
result.storage = nullptr;
@@ -1764,6 +1772,18 @@ PollResult TaskGroupBase::waitAll(AsyncTask *waitingTask) {
17641772
}
17651773

17661774
// ==== 2) Add to wait queue -------------------------------------------------
1775+
1776+
// ---- 2.1) Discarding task group may need to story the bodyError before we park
1777+
if (bodyError && isDiscardingResults()) {
1778+
auto discardingGroup = asDiscardingImpl(this);
1779+
bool storedBodyError = discardingGroup->offerBodyError(bodyError);
1780+
if (storedBodyError) {
1781+
SWIFT_TASK_GROUP_DEBUG_LOG(
1782+
this, "waitAll, stored error thrown by with...Group body, error = %p",
1783+
bodyError);
1784+
}
1785+
}
1786+
17671787
auto waitHead = waitQueue.load(std::memory_order_acquire);
17681788
_swift_tsan_release(static_cast<Job *>(waitingTask));
17691789
while (true) {

0 commit comments

Comments
 (0)