Skip to content

Commit d93c022

Browse files
authored
Merge pull request #61076 from mikeash/unsafe-continuation-validation
[Concurrency] Add an environment variable to validate unchecked continuation usage.
2 parents 9b99ed7 + afc5116 commit d93c022

File tree

6 files changed

+99
-2
lines changed

6 files changed

+99
-2
lines changed

include/swift/Runtime/EnvironmentVariables.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ extern swift::once_t initializeToken;
4040
// Concurrency library can call.
4141
SWIFT_RUNTIME_STDLIB_SPI bool concurrencyEnableJobDispatchIntegration();
4242

43+
// Wrapper around SWIFT_DEBUG_VALIDATE_UNCHECKED_CONTINUATIONS that the
44+
// Concurrency library can call.
45+
SWIFT_RUNTIME_STDLIB_SPI bool concurrencyValidateUncheckedContinuations();
46+
4347
} // end namespace environment
4448
} // end namespace runtime
4549
} // end namespace Swift

stdlib/public/Concurrency/Task.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,14 @@
3030
#include "swift/ABI/Task.h"
3131
#include "swift/ABI/TaskLocal.h"
3232
#include "swift/ABI/TaskOptions.h"
33+
#include "swift/Basic/Lazy.h"
3334
#include "swift/Runtime/Concurrency.h"
35+
#include "swift/Runtime/EnvironmentVariables.h"
3436
#include "swift/Runtime/HeapObject.h"
3537
#include "swift/Threading/Mutex.h"
3638
#include <atomic>
3739
#include <new>
40+
#include <unordered_set>
3841

3942
#if SWIFT_CONCURRENCY_ENABLE_DISPATCH
4043
#include <dispatch/dispatch.h>
@@ -1238,9 +1241,59 @@ swift_task_enqueueTaskOnExecutorImpl(AsyncTask *task, ExecutorRef executor)
12381241
task->flagAsAndEnqueueOnExecutor(executor);
12391242
}
12401243

1244+
namespace continuationChecking {
1245+
1246+
enum class State : uint8_t { Uninitialized, On, Off };
1247+
1248+
static std::atomic<State> CurrentState;
1249+
1250+
static LazyMutex ActiveContinuationsLock;
1251+
static Lazy<std::unordered_set<ContinuationAsyncContext *>> ActiveContinuations;
1252+
1253+
static bool isEnabled() {
1254+
auto state = CurrentState.load(std::memory_order_relaxed);
1255+
if (state == State::Uninitialized) {
1256+
bool enabled =
1257+
runtime::environment::concurrencyValidateUncheckedContinuations();
1258+
state = enabled ? State::On : State::Off;
1259+
CurrentState.store(state, std::memory_order_relaxed);
1260+
}
1261+
return state == State::On;
1262+
}
1263+
1264+
static void init(ContinuationAsyncContext *context) {
1265+
if (!isEnabled())
1266+
return;
1267+
1268+
LazyMutex::ScopedLock guard(ActiveContinuationsLock);
1269+
auto result = ActiveContinuations.get().insert(context);
1270+
auto inserted = std::get<1>(result);
1271+
if (!inserted)
1272+
swift_Concurrency_fatalError(
1273+
0,
1274+
"Initializing continuation context %p that was already initialized.\n",
1275+
context);
1276+
}
1277+
1278+
static void willResume(ContinuationAsyncContext *context) {
1279+
if (!isEnabled())
1280+
return;
1281+
1282+
LazyMutex::ScopedLock guard(ActiveContinuationsLock);
1283+
auto removed = ActiveContinuations.get().erase(context);
1284+
if (!removed)
1285+
swift_Concurrency_fatalError(0,
1286+
"Resuming continuation context %p that was not awaited "
1287+
"(may have already been resumed).\n",
1288+
context);
1289+
}
1290+
1291+
} // namespace continuationChecking
1292+
12411293
SWIFT_CC(swift)
12421294
static AsyncTask *swift_continuation_initImpl(ContinuationAsyncContext *context,
12431295
AsyncContinuationFlags flags) {
1296+
continuationChecking::init(context);
12441297
context->Flags = ContinuationAsyncContext::FlagsType();
12451298
if (flags.canThrow()) context->Flags.setCanThrow(true);
12461299
if (flags.isExecutorSwitchForced())
@@ -1341,6 +1394,8 @@ static void swift_continuation_awaitImpl(ContinuationAsyncContext *context) {
13411394

13421395
static void resumeTaskAfterContinuation(AsyncTask *task,
13431396
ContinuationAsyncContext *context) {
1397+
continuationChecking::willResume(context);
1398+
13441399
auto &sync = context->AwaitSynchronization;
13451400
auto status = sync.load(std::memory_order_acquire);
13461401
assert(status != ContinuationStatus::Resumed &&

stdlib/public/Concurrency/TaskPrivate.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ namespace swift {
4040
// If this is enabled, tests with `swift_task_debug_log` requirement can run.
4141
#if 0
4242
#define SWIFT_TASK_DEBUG_LOG(fmt, ...) \
43-
fprintf(stderr, "[%#lx] [%s:%d](%s) " fmt "\n", \
44-
(unsigned long)Thread::current()::platformThreadId(), __FILE__, \
43+
fprintf(stderr, "[%#lx] [%s:%d](%s) " fmt "\n", \
44+
(unsigned long)Thread::current().platformThreadId(), __FILE__, \
4545
__LINE__, __FUNCTION__, __VA_ARGS__)
4646
#else
4747
#define SWIFT_TASK_DEBUG_LOG(fmt, ...) (void)0

stdlib/public/runtime/EnvironmentVariables.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,7 @@ SWIFT_RUNTIME_STDLIB_SPI bool concurrencyEnableJobDispatchIntegration() {
246246
return runtime::environment::
247247
SWIFT_ENABLE_ASYNC_JOB_DISPATCH_INTEGRATION();
248248
}
249+
250+
SWIFT_RUNTIME_STDLIB_SPI bool concurrencyValidateUncheckedContinuations() {
251+
return runtime::environment::SWIFT_DEBUG_VALIDATE_UNCHECKED_CONTINUATIONS();
252+
}

stdlib/public/runtime/EnvironmentVariables.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ VARIABLE(SWIFT_DEBUG_ENABLE_COW_CHECKS, bool, false,
5050
VARIABLE(SWIFT_ENABLE_ASYNC_JOB_DISPATCH_INTEGRATION, bool, true,
5151
"Enable use of dispatch_async_swift_job when available.")
5252

53+
VARIABLE(SWIFT_DEBUG_VALIDATE_UNCHECKED_CONTINUATIONS, bool, false,
54+
"Check for and error on double-calls of unchecked continuations.")
55+
5356
#if defined(__APPLE__) && defined(__MACH__)
5457

5558
VARIABLE(SWIFT_DEBUG_VALIDATE_SHARED_CACHE_PROTOCOL_CONFORMANCES, bool, false,
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-build-swift -Xfrontend -disable-availability-checking -parse-as-library %s -o %t/a.out
3+
// RUN: %target-codesign %t/a.out
4+
// RUN: env %env-SWIFT_DEBUG_VALIDATE_UNCHECKED_CONTINUATIONS=1 %target-run %t/a.out
5+
6+
// REQUIRES: executable_test
7+
// REQUIRES: concurrency
8+
// REQUIRES: concurrency_runtime
9+
// UNSUPPORTED: back_deployment_runtime
10+
// UNSUPPORTED: use_os_stdlib
11+
12+
import StdlibUnittest
13+
14+
@main struct Main {
15+
static func main() async {
16+
let tests = TestSuite("ContinuationValidation")
17+
18+
if #available(SwiftStdlib 5.1, *) {
19+
tests.test("trap on double resume of unchecked continuation") {
20+
expectCrashLater(withMessage: "may have already been resumed")
21+
22+
await withUnsafeContinuation { c in
23+
c.resume(returning: ())
24+
c.resume(returning: ())
25+
}
26+
}
27+
}
28+
29+
await runAllTestsAsync()
30+
}
31+
}

0 commit comments

Comments
 (0)