Skip to content

Commit fa0f2b2

Browse files
authored
Merge pull request #75274 from mikeash/fix-unsafe-continuation-validation
[Concurrency] Fix unsafe continuation validation when a continued task has been destroyed.
2 parents 22db563 + c1772eb commit fa0f2b2

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

stdlib/public/Concurrency/Task.cpp

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1397,7 +1397,7 @@ enum class State : uint8_t { Uninitialized, On, Off };
13971397
static std::atomic<State> CurrentState;
13981398

13991399
static LazyMutex ActiveContinuationsLock;
1400-
static Lazy<std::unordered_set<ContinuationAsyncContext *>> ActiveContinuations;
1400+
static Lazy<std::unordered_set<AsyncTask *>> ActiveContinuations;
14011401

14021402
static bool isEnabled() {
14031403
auto state = CurrentState.load(std::memory_order_relaxed);
@@ -1410,39 +1410,39 @@ static bool isEnabled() {
14101410
return state == State::On;
14111411
}
14121412

1413-
static void init(ContinuationAsyncContext *context) {
1413+
static void init(AsyncTask *task) {
14141414
if (!isEnabled())
14151415
return;
14161416

14171417
LazyMutex::ScopedLock guard(ActiveContinuationsLock);
1418-
auto result = ActiveContinuations.get().insert(context);
1418+
auto result = ActiveContinuations.get().insert(task);
14191419
auto inserted = std::get<1>(result);
14201420
if (!inserted)
14211421
swift_Concurrency_fatalError(
14221422
0,
1423-
"Initializing continuation context %p that was already initialized.\n",
1424-
context);
1423+
"Initializing continuation for task %p that was already initialized.\n",
1424+
task);
14251425
}
14261426

1427-
static void willResume(ContinuationAsyncContext *context) {
1427+
static void willResume(AsyncTask *task) {
14281428
if (!isEnabled())
14291429
return;
14301430

14311431
LazyMutex::ScopedLock guard(ActiveContinuationsLock);
1432-
auto removed = ActiveContinuations.get().erase(context);
1432+
auto removed = ActiveContinuations.get().erase(task);
14331433
if (!removed)
1434-
swift_Concurrency_fatalError(0,
1435-
"Resuming continuation context %p that was not awaited "
1436-
"(may have already been resumed).\n",
1437-
context);
1434+
swift_Concurrency_fatalError(
1435+
0,
1436+
"Resuming continuation for task %p that is not awaited "
1437+
"(may have already been resumed).\n",
1438+
task);
14381439
}
14391440

14401441
} // namespace continuationChecking
14411442

14421443
SWIFT_CC(swift)
14431444
static AsyncTask *swift_continuation_initImpl(ContinuationAsyncContext *context,
14441445
AsyncContinuationFlags flags) {
1445-
continuationChecking::init(context);
14461446
context->Flags = ContinuationAsyncContext::FlagsType();
14471447
if (flags.canThrow()) context->Flags.setCanThrow(true);
14481448
if (flags.isExecutorSwitchForced())
@@ -1479,6 +1479,7 @@ static AsyncTask *swift_continuation_initImpl(ContinuationAsyncContext *context,
14791479
task->ResumeContext = context;
14801480
task->ResumeTask = context->ResumeParent;
14811481

1482+
continuationChecking::init(task);
14821483
concurrency::trace::task_continuation_init(task, context);
14831484

14841485
return task;
@@ -1593,8 +1594,6 @@ static void swift_continuation_awaitImpl(ContinuationAsyncContext *context) {
15931594

15941595
static void resumeTaskAfterContinuation(AsyncTask *task,
15951596
ContinuationAsyncContext *context) {
1596-
continuationChecking::willResume(context);
1597-
15981597
auto &sync = context->AwaitSynchronization;
15991598

16001599
auto status = sync.load(std::memory_order_acquire);
@@ -1644,13 +1643,15 @@ static void resumeTaskAfterContinuation(AsyncTask *task,
16441643

16451644
SWIFT_CC(swift)
16461645
static void swift_continuation_resumeImpl(AsyncTask *task) {
1646+
continuationChecking::willResume(task);
16471647
auto context = static_cast<ContinuationAsyncContext*>(task->ResumeContext);
16481648
concurrency::trace::task_continuation_resume(context, false);
16491649
resumeTaskAfterContinuation(task, context);
16501650
}
16511651

16521652
SWIFT_CC(swift)
16531653
static void swift_continuation_throwingResumeImpl(AsyncTask *task) {
1654+
continuationChecking::willResume(task);
16541655
auto context = static_cast<ContinuationAsyncContext*>(task->ResumeContext);
16551656
concurrency::trace::task_continuation_resume(context, false);
16561657
resumeTaskAfterContinuation(task, context);
@@ -1660,6 +1661,7 @@ static void swift_continuation_throwingResumeImpl(AsyncTask *task) {
16601661
SWIFT_CC(swift)
16611662
static void swift_continuation_throwingResumeWithErrorImpl(AsyncTask *task,
16621663
/* +1 */ SwiftError *error) {
1664+
continuationChecking::willResume(task);
16631665
auto context = static_cast<ContinuationAsyncContext*>(task->ResumeContext);
16641666
concurrency::trace::task_continuation_resume(context, true);
16651667
context->ErrorResult = error;

0 commit comments

Comments
 (0)