Skip to content

Commit 18e8919

Browse files
committed
Task group children should be executed on parent task at time of await
1 parent af7738f commit 18e8919

File tree

5 files changed

+195
-67
lines changed

5 files changed

+195
-67
lines changed

stdlib/public/Concurrency/Task.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,9 @@ static AsyncTaskAndContext swift_task_create_commonImpl(
704704
}
705705
case TaskOptionRecordKind::RunInline: {
706706
runInlineOption = cast<RunInlineTaskOptionRecord>(option);
707+
// TODO (rokhinip): We seem to be creating runInline tasks like detached
708+
// tasks but they need to maintain the voucher and priority of calling
709+
// thread and therefore need to behave a bit more like SC child tasks.
707710
break;
708711
}
709712
}
@@ -974,6 +977,15 @@ static AsyncTaskAndContext swift_task_create_commonImpl(
974977
// Attach to the group, if needed.
975978
if (group) {
976979
swift_taskGroup_attachChild(group, task);
980+
#if SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL
981+
// We need to take a retain here to keep the child task for the task group
982+
// alive. In the non-task-to-thread model, we'd always take this retain
983+
// below since we'd enqueue the child task. But since we're not going to be
984+
// enqueueing the child task in this model, we need to take this +1 to
985+
// balance out the release that exists after the task group child task
986+
// creation
987+
swift_retain(task);
988+
#endif
977989
}
978990

979991
// If we're supposed to copy task locals, do so now.
@@ -988,6 +1000,9 @@ static AsyncTaskAndContext swift_task_create_commonImpl(
9881000

9891001
// If we're supposed to enqueue the task, do so now.
9901002
if (taskCreateFlags.enqueueJob()) {
1003+
#if SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL
1004+
assert(false && "Should not be enqueuing tasks in task-to-thread model");
1005+
#endif
9911006
swift_retain(task);
9921007
task->flagAsAndEnqueueOnExecutor(executor);
9931008
}

stdlib/public/Concurrency/TaskGroup.cpp

Lines changed: 119 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -284,15 +284,24 @@ class TaskGroupImpl: public TaskGroupTaskStatusRecord {
284284
};
285285

286286
private:
287-
#if !SWIFT_STDLIB_SINGLE_THREADED_CONCURRENCY
287+
#if SWIFT_STDLIB_SINGLE_THREADED_CONCURRENCY || SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL
288+
// Synchronization is simple here. In a single threaded mode, all swift tasks
289+
// run on a single thread so no coordination is needed. In a task-to-thread
290+
// model, only the parent task which created the task group can
291+
//
292+
// (a) add child tasks to a group
293+
// (b) run the child tasks
294+
//
295+
// So we shouldn't need to worry about coordinating between child tasks and
296+
// parents in a task group
297+
void lock() const {}
298+
void unlock() const {}
299+
#else
288300
// TODO: move to lockless via the status atomic (make readyQueue an mpsc_queue_t<ReadyQueueItem>)
289301
mutable std::mutex mutex_;
290302

291303
void lock() const { mutex_.lock(); }
292304
void unlock() const { mutex_.unlock(); }
293-
#else
294-
void lock() const {}
295-
void unlock() const {}
296305
#endif
297306

298307
/// Used for queue management, counting number of waiting and ready tasks
@@ -437,6 +446,11 @@ class TaskGroupImpl: public TaskGroupTaskStatusRecord {
437446
/// or a `PollStatus::MustWait` result if there are tasks in flight
438447
/// and the waitingTask eventually be woken up by a completion.
439448
PollResult poll(AsyncTask *waitingTask);
449+
450+
private:
451+
// Enqueue the completed task onto ready queue if there are no waiting tasks
452+
// yet
453+
void enqueueCompletedTask(AsyncTask *completedTask, bool hadErrorResult);
440454
};
441455

442456
} // end anonymous namespace
@@ -569,6 +583,22 @@ static void fillGroupNextResult(TaskFutureWaitAsyncContext *context,
569583
}
570584
}
571585

586+
// TaskGroup is locked upon entry and exit
587+
void TaskGroupImpl::enqueueCompletedTask(AsyncTask *completedTask, bool hadErrorResult) {
588+
// Retain the task while it is in the queue;
589+
// it must remain alive until the task group is alive.
590+
swift_retain(completedTask);
591+
592+
auto readyItem = ReadyQueueItem::get(
593+
hadErrorResult ? ReadyStatus::Error : ReadyStatus::Success,
594+
completedTask
595+
);
596+
597+
assert(completedTask == readyItem.getTask());
598+
assert(readyItem.getTask()->isFuture());
599+
readyQueue.enqueue(readyItem);
600+
}
601+
572602
void TaskGroupImpl::offer(AsyncTask *completedTask, AsyncContext *context) {
573603
assert(completedTask);
574604
assert(completedTask->isFuture());
@@ -610,52 +640,58 @@ void TaskGroupImpl::offer(AsyncTask *completedTask, AsyncContext *context) {
610640
if (waitQueue.compare_exchange_strong(
611641
waitingTask, nullptr,
612642
/*success*/ std::memory_order_release,
613-
/*failure*/ std::memory_order_acquire) &&
614-
statusCompletePendingReadyWaiting(assumed)) {
615-
// Run the task.
616-
auto result = PollResult::get(completedTask, hadErrorResult);
643+
/*failure*/ std::memory_order_acquire)) {
644+
#if SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL
645+
// We have completed a child task in a task group task and we know
646+
// there is a waiting task who will reevaluate TaskGroupImpl::poll once
647+
// we return, by virtue of being in the task-to-thread model.
648+
// We want poll() to then satisfy the condition of having readyTasks()
649+
// that it can dequeue from the readyQueue so we need to enqueue our
650+
// completion.
651+
652+
// TODO (rokhinip): There's probably a more efficient way to deal with
653+
// this since the child task can directly offer the result to the
654+
// parent who will run next but that requires a fair bit of plumbing
655+
enqueueCompletedTask(completedTask, hadErrorResult);
656+
unlock(); // TODO: remove fragment lock, and use status for synchronization
657+
return;
658+
#else /* SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL */
659+
if (statusCompletePendingReadyWaiting(assumed)) {
660+
// Run the task.
661+
auto result = PollResult::get(completedTask, hadErrorResult);
617662

618-
unlock(); // TODO: remove fragment lock, and use status for synchronization
619-
620-
auto waitingContext =
621-
static_cast<TaskFutureWaitAsyncContext *>(
622-
waitingTask->ResumeContext);
663+
unlock(); // TODO: remove fragment lock, and use status for synchronization
623664

624-
fillGroupNextResult(waitingContext, result);
625-
detachChild(result.retainedTask);
665+
auto waitingContext =
666+
static_cast<TaskFutureWaitAsyncContext *>(
667+
waitingTask->ResumeContext);
626668

627-
_swift_tsan_acquire(static_cast<Job *>(waitingTask));
669+
fillGroupNextResult(waitingContext, result);
670+
detachChild(result.retainedTask);
628671

629-
// TODO: allow the caller to suggest an executor
630-
waitingTask->flagAsAndEnqueueOnExecutor(ExecutorRef::generic());
631-
return;
632-
} // else, try again
672+
_swift_tsan_acquire(static_cast<Job *>(waitingTask));
673+
// TODO: allow the caller to suggest an executor
674+
waitingTask->flagAsAndEnqueueOnExecutor(ExecutorRef::generic());
675+
return;
676+
} // else, try again
677+
#endif
678+
}
633679
}
634-
635680
llvm_unreachable("should have enqueued and returned.");
681+
} else {
682+
// ==== b) enqueue completion ------------------------------------------------
683+
//
684+
// else, no-one was waiting (yet), so we have to instead enqueue to the message
685+
// queue when a task polls during next() it will notice that we have a value
686+
// ready for it, and will process it immediately without suspending.
687+
assert(!waitQueue.load(std::memory_order_relaxed));
688+
689+
SWIFT_TASK_DEBUG_LOG("group has no waiting tasks, RETAIN and store ready task = %p",
690+
completedTask);
691+
enqueueCompletedTask(completedTask, hadErrorResult);
692+
unlock(); // TODO: remove fragment lock, and use status for synchronization
636693
}
637694

638-
// ==== b) enqueue completion ------------------------------------------------
639-
//
640-
// else, no-one was waiting (yet), so we have to instead enqueue to the message
641-
// queue when a task polls during next() it will notice that we have a value
642-
// ready for it, and will process it immediately without suspending.
643-
assert(!waitQueue.load(std::memory_order_relaxed));
644-
SWIFT_TASK_DEBUG_LOG("group has no waiting tasks, RETAIN and store ready task = %p",
645-
completedTask);
646-
// Retain the task while it is in the queue;
647-
// it must remain alive until the task group is alive.
648-
swift_retain(completedTask);
649-
650-
auto readyItem = ReadyQueueItem::get(
651-
hadErrorResult ? ReadyStatus::Error : ReadyStatus::Success,
652-
completedTask
653-
);
654-
655-
assert(completedTask == readyItem.getTask());
656-
assert(readyItem.getTask()->isFuture());
657-
readyQueue.enqueue(readyItem);
658-
unlock(); // TODO: remove fragment lock, and use status for synchronization
659695
return;
660696
}
661697

@@ -694,7 +730,7 @@ static void swift_taskGroup_wait_next_throwingImpl(
694730
OpaqueValue *resultPointer, SWIFT_ASYNC_CONTEXT AsyncContext *callerContext,
695731
TaskGroup *_group,
696732
ThrowingTaskFutureWaitContinuationFunction *resumeFunction,
697-
AsyncContext *rawContext) {
733+
AsyncContext *rawContext) SWIFT_OPTNONE {
698734
auto waitingTask = swift_task_getCurrent();
699735
waitingTask->ResumeTask = task_group_wait_resume_adapter;
700736
waitingTask->ResumeContext = rawContext;
@@ -719,9 +755,9 @@ static void swift_taskGroup_wait_next_throwingImpl(
719755
#ifdef __ARM_ARCH_7K__
720756
return workaround_function_swift_taskGroup_wait_next_throwingImpl(
721757
resultPointer, callerContext, _group, resumeFunction, rawContext);
722-
#else
758+
#else /* __ARM_ARCH_7K__ */
723759
return;
724-
#endif
760+
#endif /* __ARM_ARCH_7K__ */
725761

726762
case PollStatus::Empty:
727763
case PollStatus::Error:
@@ -739,16 +775,25 @@ static void swift_taskGroup_wait_next_throwingImpl(
739775
}
740776
}
741777

742-
PollResult TaskGroupImpl::poll(AsyncTask *waitingTask) {
778+
PollResult TaskGroupImpl::poll(AsyncTask *waitingTask) SWIFT_OPTNONE {
743779
lock(); // TODO: remove group lock, and use status for synchronization
744780
SWIFT_TASK_DEBUG_LOG("poll group = %p", this);
745-
auto assumed = statusMarkWaitingAssumeAcquire();
746781

747782
PollResult result;
748783
result.storage = nullptr;
749784
result.successType = nullptr;
750785
result.retainedTask = nullptr;
751786

787+
// Have we suspended the task?
788+
bool hasSuspended = false;
789+
bool haveRunOneChildTaskInline = false;
790+
791+
reevaluate_if_taskgroup_has_results:;
792+
auto assumed = statusMarkWaitingAssumeAcquire();
793+
if (haveRunOneChildTaskInline) {
794+
assert(assumed.readyTasks());
795+
}
796+
752797
// ==== 1) bail out early if no tasks are pending ----------------------------
753798
if (assumed.isEmpty()) {
754799
SWIFT_TASK_DEBUG_LOG("poll group = %p, group is empty, no pending tasks", this);
@@ -762,9 +807,6 @@ PollResult TaskGroupImpl::poll(AsyncTask *waitingTask) {
762807
return result;
763808
}
764809

765-
// Have we suspended the task?
766-
bool hasSuspended = false;
767-
768810
auto waitHead = waitQueue.load(std::memory_order_acquire);
769811

770812
// ==== 2) Ready task was polled, return with it immediately -----------------
@@ -779,17 +821,17 @@ PollResult TaskGroupImpl::poll(AsyncTask *waitingTask) {
779821
/*success*/ std::memory_order_relaxed,
780822
/*failure*/ std::memory_order_acquire)) {
781823

782-
// Success! We are allowed to poll.
783-
ReadyQueueItem item;
784-
bool taskDequeued = readyQueue.dequeue(item);
785-
assert(taskDequeued); (void) taskDequeued;
786-
787824
// We're going back to running the task, so if we suspended before,
788825
// we need to flag it as running again.
789826
if (hasSuspended) {
790827
waitingTask->flagAsRunning();
791828
}
792829

830+
// Success! We are allowed to poll.
831+
ReadyQueueItem item;
832+
bool taskDequeued = readyQueue.dequeue(item);
833+
assert(taskDequeued); (void) taskDequeued;
834+
793835
assert(item.getTask()->isFuture());
794836
auto futureFragment = item.getTask()->futureFragment();
795837

@@ -845,6 +887,28 @@ PollResult TaskGroupImpl::poll(AsyncTask *waitingTask) {
845887
/*success*/ std::memory_order_release,
846888
/*failure*/ std::memory_order_acquire)) {
847889
unlock(); // TODO: remove fragment lock, and use status for synchronization
890+
#if SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL
891+
// The logic here is paired with the logic in TaskGroupImpl::offer. Once
892+
// we run the
893+
auto oldTask = _swift_task_clearCurrent();
894+
assert(oldTask == waitingTask);
895+
896+
auto childTask = getTaskRecord()->getFirstChild();
897+
assert(childTask != NULL);
898+
899+
SWIFT_TASK_DEBUG_LOG("[RunInline] Switching away from running %p to now running %p", oldTask, childTask);
900+
// Run the new task on the same thread now - this should run the new task to
901+
// completion. All swift tasks in task-to-thread model run on generic
902+
// executor
903+
swift_job_run(childTask, ExecutorRef::generic());
904+
haveRunOneChildTaskInline = true;
905+
906+
SWIFT_TASK_DEBUG_LOG("[RunInline] Switching back from running %p to now running %p", childTask, oldTask);
907+
// We are back to being the parent task and now that we've run the child
908+
// task, we should reevaluate parent task
909+
_swift_task_setCurrent(oldTask);
910+
goto reevaluate_if_taskgroup_has_results;
911+
#endif
848912
// no ready tasks, so we must wait.
849913
result.status = PollStatus::MustWait;
850914
_swift_task_clearCurrent();

stdlib/public/Concurrency/TaskGroup.swift

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -238,11 +238,19 @@ public struct TaskGroup<ChildTaskResult: Sendable> {
238238
operation: __owned @Sendable @escaping () async -> ChildTaskResult
239239
) {
240240
#if compiler(>=5.5) && $BuiltinCreateAsyncTaskInGroup
241+
#if SWIFT_STDLIB_TASK_TO_THREAD_MODEL_CONCURRENCY
242+
let flags = taskCreateFlags(
243+
priority: priority, isChildTask: true, copyTaskLocals: false,
244+
inheritContext: false, enqueueJob: false,
245+
addPendingGroupTaskUnconditionally: true
246+
)
247+
#else
241248
let flags = taskCreateFlags(
242249
priority: priority, isChildTask: true, copyTaskLocals: false,
243250
inheritContext: false, enqueueJob: true,
244251
addPendingGroupTaskUnconditionally: true
245252
)
253+
#endif
246254

247255
// Create the task in this group.
248256
_ = Builtin.createAsyncTaskInGroup(flags, _group, operation)
@@ -272,12 +280,19 @@ public struct TaskGroup<ChildTaskResult: Sendable> {
272280
// the group is cancelled and is not accepting any new work
273281
return false
274282
}
275-
283+
#if SWIFT_STDLIB_TASK_TO_THREAD_MODEL_CONCURRENCY
284+
let flags = taskCreateFlags(
285+
priority: priority, isChildTask: true, copyTaskLocals: false,
286+
inheritContext: false, enqueueJob: false,
287+
addPendingGroupTaskUnconditionally: false
288+
)
289+
#else
276290
let flags = taskCreateFlags(
277291
priority: priority, isChildTask: true, copyTaskLocals: false,
278292
inheritContext: false, enqueueJob: true,
279293
addPendingGroupTaskUnconditionally: false
280294
)
295+
#endif
281296

282297
// Create the task in this group.
283298
_ = Builtin.createAsyncTaskInGroup(flags, _group, operation)
@@ -634,8 +649,8 @@ public struct ThrowingTaskGroup<ChildTaskResult: Sendable, Failure: Error> {
634649
/// guard let result = await group.nextResult() else {
635650
/// return // No task to wait on, which won't happen in this example.
636651
/// }
637-
///
638-
/// switch result {
652+
///
653+
/// switch result {
639654
/// case .success(let value): print(value)
640655
/// case .failure(let error): print("Failure: \(error)")
641656
/// }
@@ -810,15 +825,15 @@ extension ThrowingTaskGroup: AsyncSequence {
810825
/// group.addTask { 1 }
811826
/// group.addTask { throw SomeError }
812827
/// group.addTask { 2 }
813-
///
814-
/// do {
828+
///
829+
/// do {
815830
/// // Assuming the child tasks complete in order, this prints "1"
816831
/// // and then throws an error.
817832
/// for try await r in group { print(r) }
818833
/// } catch {
819834
/// // Resolve the error.
820835
/// }
821-
///
836+
///
822837
/// // Assuming the child tasks complete in order, this prints "2".
823838
/// for try await r in group { print(r) }
824839
///
@@ -847,7 +862,7 @@ extension ThrowingTaskGroup: AsyncSequence {
847862
/// this iterator is guaranteed to never produce more values.
848863
///
849864
/// For more information about the iteration order and semantics,
850-
/// see `ThrowingTaskGroup.next()`
865+
/// see `ThrowingTaskGroup.next()`
851866
///
852867
/// - Throws: The error thrown by the next child task that completes.
853868
///

0 commit comments

Comments
 (0)