Skip to content

Concurrency: Factor atomic operations in Task.sleep() into a Sendable wrapper #72062

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 57 additions & 48 deletions stdlib/public/Concurrency/TaskSleep.swift
Original file line number Diff line number Diff line change
Expand Up @@ -105,25 +105,55 @@ extension Task where Success == Never, Failure == Never {
}
}

/// A simple wrapper for a pointer to heap allocated storage of a `SleepState`
/// value. This wrapper is `Sendable` because it facilitates atomic load and
/// exchange operations on the underlying storage. However, this wrapper is also
/// _unsafe_ because the owner must manually deallocate the token once it is no
/// longer needed.
struct UnsafeSleepStateToken: @unchecked Sendable {
let wordPtr: UnsafeMutablePointer<Builtin.Word>

/// Allocates the underlying storage and sets the value to `.notStarted`.
init() {
wordPtr = .allocate(capacity: 1)
Builtin.atomicstore_seqcst_Word(
wordPtr._rawValue, SleepState.notStarted.word._builtinWordValue)
}

/// Atomically loads the current state.
func load() -> SleepState {
return SleepState(word: Builtin.atomicload_seqcst_Word(wordPtr._rawValue))
}

/// Attempts to atomically set the stored value to `desired` if the current
/// value is equal to `expected`. Returns true if the exchange was successful.
func exchange(expected: SleepState, desired: SleepState) -> Bool {
let (_, won) = Builtin.cmpxchg_seqcst_seqcst_Word(
wordPtr._rawValue,
expected.word._builtinWordValue,
desired.word._builtinWordValue)
return Bool(_builtinBooleanLiteral: won)
}

/// Deallocates the underlying storage.
func deallocate() {
wordPtr.deallocate()
}
}

/// Called when the sleep(nanoseconds:) operation woke up without being
/// canceled.
static func onSleepWake(
_ wordPtr: UnsafeMutablePointer<Builtin.Word>
) {
static func onSleepWake(_ token: UnsafeSleepStateToken) {
while true {
let state = SleepState(loading: wordPtr)
let state = token.load()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

late to the party here -- I was looking into this recently and it turned out we didn't properly handle nonisolated(unsafe) for local variables (heh). That was fixed since then so I wondered if we could just nonisolated(unsafe) the wordPtr now and avoid the wrapper entirely.

Although avoiding the alloc entirely would be even better...! So thanks for investigating that too.

switch state {
case .notStarted:
fatalError("Cannot wake before we even started")

case .activeContinuation(let continuation):
// We have an active continuation, so try to transition to the
// "finished" state.
let (_, won) = Builtin.cmpxchg_seqcst_seqcst_Word(
wordPtr._rawValue,
state.word._builtinWordValue,
SleepState.finished.word._builtinWordValue)
if Bool(_builtinBooleanLiteral: won) {
if token.exchange(expected: state, desired: .finished) {
// The sleep finished, so invoke the continuation: we're done.
continuation.resume()
return
Expand All @@ -137,9 +167,9 @@ extension Task where Success == Never, Failure == Never {

case .cancelled:
// The task was cancelled, which means the continuation was
// called by the cancellation handler. We need to deallocate the flag
// word, because it was left over for this task to complete.
wordPtr.deallocate()
// called by the cancellation handler. We need to deallocate the token
// because it was left over for this task to complete.
token.deallocate()
return

case .cancelledBeforeStarted:
Expand All @@ -151,20 +181,14 @@ extension Task where Success == Never, Failure == Never {

/// Called when the sleep(nanoseconds:) operation has been canceled before
/// the sleep completed.
static func onSleepCancel(
_ wordPtr: UnsafeMutablePointer<Builtin.Word>
) {
static func onSleepCancel(_ token: UnsafeSleepStateToken) {
while true {
let state = SleepState(loading: wordPtr)
let state = token.load()
switch state {
case .notStarted:
// We haven't started yet, so try to transition to the cancelled-before
// started state.
let (_, won) = Builtin.cmpxchg_seqcst_seqcst_Word(
wordPtr._rawValue,
state.word._builtinWordValue,
SleepState.cancelledBeforeStarted.word._builtinWordValue)
if Bool(_builtinBooleanLiteral: won) {
if token.exchange(expected: state, desired: .cancelledBeforeStarted) {
return
}

Expand All @@ -174,11 +198,7 @@ extension Task where Success == Never, Failure == Never {
case .activeContinuation(let continuation):
// We have an active continuation, so try to transition to the
// "cancelled" state.
let (_, won) = Builtin.cmpxchg_seqcst_seqcst_Word(
wordPtr._rawValue,
state.word._builtinWordValue,
SleepState.cancelled.word._builtinWordValue)
if Bool(_builtinBooleanLiteral: won) {
if token.exchange(expected: state, desired: .cancelled) {
// We recorded the task cancellation before the sleep finished, so
// invoke the continuation with the cancellation error.
continuation.resume(throwing: _Concurrency.CancellationError())
Expand All @@ -203,33 +223,22 @@ extension Task where Success == Never, Failure == Never {
///
/// This function doesn't block the underlying thread.
public static func sleep(nanoseconds duration: UInt64) async throws {
// Allocate storage for the storage word.
let wordPtr = UnsafeMutablePointer<Builtin.Word>.allocate(capacity: 1)

// Initialize the flag word to "not started", which means the continuation
// has neither been created nor completed.
Builtin.atomicstore_seqcst_Word(
wordPtr._rawValue, SleepState.notStarted.word._builtinWordValue)
// Create a token which will initially have the value "not started", which
// means the continuation has neither been created nor completed.
let token = UnsafeSleepStateToken()

do {
// Install a cancellation handler to resume the continuation by
// throwing CancellationError.
try await withTaskCancellationHandler {
let _: () = try await withUnsafeThrowingContinuation { continuation in
while true {
let state = SleepState(loading: wordPtr)
let state = token.load()
switch state {
case .notStarted:
// The word that describes the active continuation state.
let continuationWord =
SleepState.activeContinuation(continuation).word

// Try to swap in the continuation word.
let (_, won) = Builtin.cmpxchg_seqcst_seqcst_Word(
wordPtr._rawValue,
state.word._builtinWordValue,
continuationWord._builtinWordValue)
if !Bool(_builtinBooleanLiteral: won) {
// Try to swap in the continuation state.
let newState = SleepState.activeContinuation(continuation)
if !token.exchange(expected: state, desired: newState) {
// Keep trying!
continue
}
Expand All @@ -243,7 +252,7 @@ extension Task where Success == Never, Failure == Never {
addPendingGroupTaskUnconditionally: false,
isDiscardingTask: false)
let (sleepTask, _) = Builtin.createAsyncTask(sleepTaskFlags) {
onSleepWake(wordPtr)
onSleepWake(token)
}
_enqueueJobGlobalWithDelay(
duration, Builtin.convertTaskToJob(sleepTask))
Expand All @@ -264,12 +273,12 @@ extension Task where Success == Never, Failure == Never {
}
}
} onCancel: {
onSleepCancel(wordPtr)
onSleepCancel(token)
}

// Determine whether we got cancelled before we even started.
let cancelledBeforeStarted: Bool
switch SleepState(loading: wordPtr) {
switch token.load() {
case .notStarted, .activeContinuation, .cancelled:
fatalError("Invalid state for non-cancelled sleep task")

Expand All @@ -282,7 +291,7 @@ extension Task where Success == Never, Failure == Never {

// We got here without being cancelled, so deallocate the storage for
// the flag word and continuation.
wordPtr.deallocate()
token.deallocate()

// If we got cancelled before we even started, through the cancellation
// error now.
Expand Down
31 changes: 10 additions & 21 deletions stdlib/public/Concurrency/TaskSleepDuration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,33 +21,22 @@ extension Task where Success == Never, Failure == Never {
tolerance: Duration?,
clock: _ClockID
) async throws {
// Allocate storage for the storage word.
let wordPtr = UnsafeMutablePointer<Builtin.Word>.allocate(capacity: 1)

// Initialize the flag word to "not started", which means the continuation
// has neither been created nor completed.
Builtin.atomicstore_seqcst_Word(
wordPtr._rawValue, SleepState.notStarted.word._builtinWordValue)
// Create a token which will initially have the value "not started", which
// means the continuation has neither been created nor completed.
let token = UnsafeSleepStateToken()

do {
// Install a cancellation handler to resume the continuation by
// throwing CancellationError.
try await withTaskCancellationHandler {
let _: () = try await withUnsafeThrowingContinuation { continuation in
while true {
let state = SleepState(loading: wordPtr)
let state = token.load()
switch state {
case .notStarted:
// The word that describes the active continuation state.
let continuationWord =
SleepState.activeContinuation(continuation).word

// Try to swap in the continuation word.
let (_, won) = Builtin.cmpxchg_seqcst_seqcst_Word(
wordPtr._rawValue,
state.word._builtinWordValue,
continuationWord._builtinWordValue)
if !Bool(_builtinBooleanLiteral: won) {
let newState = SleepState.activeContinuation(continuation)
if !token.exchange(expected: state, desired: newState) {
// Keep trying!
continue
}
Expand All @@ -61,7 +50,7 @@ extension Task where Success == Never, Failure == Never {
addPendingGroupTaskUnconditionally: false,
isDiscardingTask: false)
let (sleepTask, _) = Builtin.createAsyncTask(sleepTaskFlags) {
onSleepWake(wordPtr)
onSleepWake(token)
}
let toleranceSeconds: Int64
let toleranceNanoseconds: Int64
Expand Down Expand Up @@ -94,12 +83,12 @@ extension Task where Success == Never, Failure == Never {
}
}
} onCancel: {
onSleepCancel(wordPtr)
onSleepCancel(token)
}

// Determine whether we got cancelled before we even started.
let cancelledBeforeStarted: Bool
switch SleepState(loading: wordPtr) {
switch token.load() {
case .notStarted, .activeContinuation, .cancelled:
fatalError("Invalid state for non-cancelled sleep task")

Expand All @@ -112,7 +101,7 @@ extension Task where Success == Never, Failure == Never {

// We got here without being cancelled, so deallocate the storage for
// the flag word and continuation.
wordPtr.deallocate()
token.deallocate()

// If we got cancelled before we even started, through the cancellation
// error now.
Expand Down