Skip to content

Commit 3d17cad

Browse files
committed
Implement request cancellation
When receiving a `CancellationNotification`, we cancel the task that handles the request with that ID. This will cause `cancel_notification` to be sent to sourcekitd or a `CancellationNotification` to be sent to `clangd`, which ultimately cancels the request. rdar://117492860
1 parent 87dd95e commit 3d17cad

File tree

14 files changed

+229
-45
lines changed

14 files changed

+229
-45
lines changed

Package.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ let package = Package(
103103
name: "LanguageServerProtocol",
104104
dependencies: [
105105
"LSPLogging",
106+
"SKSupport",
106107
.product(name: "SwiftToolsSupport-auto", package: "swift-tools-support-core"),
107108
],
108109
exclude: ["CMakeLists.txt"]

Sources/LSPTestSupport/Assertions.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public func assertNoThrow<T>(
3636
}
3737
}
3838

39-
/// Same as `XCTAssertThrows` but executes the trailing closure.
39+
/// Same as `XCTAssertThrows` but allows the expression to be async
4040
public func assertThrowsError<T>(
4141
_ expression: @autoclosure () async throws -> T,
4242
_ message: @autoclosure () -> String = "",

Sources/LanguageServerProtocol/AsyncQueue.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ public final class AsyncQueue<TaskMetadata: DependencyTracker> {
8888
let throwingTask = asyncThrowing(priority: priority, metadata: metadata, operation: operation)
8989
return Task {
9090
do {
91-
return try await throwingTask.value
91+
return try await throwingTask.valuePropagatingCancellation
9292
} catch {
9393
// We know this can never happen because `operation` does not throw.
9494
preconditionFailure("Executing a task threw an error even though the operation did not throw")
@@ -141,7 +141,7 @@ public final class AsyncQueue<TaskMetadata: DependencyTracker> {
141141

142142
/// Convenience overloads for serial queues.
143143
extension AsyncQueue where TaskMetadata == Serial {
144-
/// Same as ``async(priority:operation:)`` but specialized for serial queues
144+
/// Same as ``async(priority:operation:)`` but specialized for serial queues
145145
/// that don't specify any metadata.
146146
@discardableResult
147147
public func async<Success: Sendable>(
@@ -151,7 +151,7 @@ extension AsyncQueue where TaskMetadata == Serial {
151151
return self.async(priority: priority, metadata: Serial(), operation: operation)
152152
}
153153

154-
/// Same as ``asyncThrowing(priority:metadata:operation:)`` but specialized
154+
/// Same as ``asyncThrowing(priority:metadata:operation:)`` but specialized
155155
/// for serial queues that don't specify any metadata.
156156
public func asyncThrowing<Success: Sendable>(
157157
priority: TaskPriority? = nil,

Sources/LanguageServerProtocol/CMakeLists.txt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,8 @@ add_library(LanguageServerProtocol STATIC
138138
set_target_properties(LanguageServerProtocol PROPERTIES
139139
INTERFACE_INCLUDE_DIRECTORIES ${CMAKE_Swift_MODULE_DIRECTORY})
140140
target_link_libraries(LanguageServerProtocol PUBLIC
141+
LSPLogging
142+
SKSupport
141143
TSCBasic
142144
$<$<NOT:$<PLATFORM_ID:Darwin>>:swiftDispatch>
143145
$<$<NOT:$<PLATFORM_ID:Darwin>>:Foundation>)
144-
145-
target_link_libraries(LanguageServerProtocol PUBLIC
146-
LSPLogging
147-
)

Sources/LanguageServerProtocol/Connection.swift

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
import Dispatch
14+
import SKSupport
1415

1516
/// An abstract connection, allow messages to be sent to a (potentially remote) `MessageHandler`.
1617
public protocol Connection: AnyObject {
@@ -144,3 +145,30 @@ extension LocalConnection: Connection {
144145
return id
145146
}
146147
}
148+
149+
extension Connection {
150+
/// Send the given request to the connection and await its result.
151+
///
152+
/// This method automatically sends a `CancelRequestNotification` to the
153+
/// connection if the task it is executing in is being cancelled.
154+
///
155+
/// - Warning: Because this message is `async`, it does not provide any ordering
156+
/// guarantees. If you need to gurantee that messages are sent in-order
157+
/// use the version with a completion handler.
158+
public func send<R: RequestType>(_ request: R) async throws -> R.Response {
159+
let requestIDWrapper = ThreadSafeBox<RequestID?>(initialValue: nil)
160+
try Task.checkCancellation()
161+
return try await withTaskCancellationHandler {
162+
try await withCheckedThrowingContinuation { continuation in
163+
let requestID = self.send(request) { result in
164+
continuation.resume(with: result)
165+
}
166+
requestIDWrapper.value = requestID
167+
}
168+
} onCancel: {
169+
if let requestID = requestIDWrapper.value {
170+
self.send(CancelRequestNotification(id: requestID))
171+
}
172+
}
173+
}
174+
}

Sources/SKSupport/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
add_library(SKSupport STATIC
33
BuildConfiguration.swift
44
ByteString.swift
5+
dlopen.swift
56
FileSystem.swift
67
LineTable.swift
78
Random.swift
89
Result.swift
9-
dlopen.swift)
10+
Task+ValuePropagatingCancellation.swift
11+
ThreadSafeBox.swift
12+
)
1013
set_target_properties(SKSupport PROPERTIES
1114
INTERFACE_INCLUDE_DIRECTORIES ${CMAKE_Swift_MODULE_DIRECTORY})
1215
target_link_libraries(SKSupport PRIVATE
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
public extension Task {
14+
/// Awaits the value of the result.
15+
///
16+
/// If the current task is cancelled, this will cancel the subtask as well.
17+
var valuePropagatingCancellation: Success {
18+
get async throws {
19+
try await withTaskCancellationHandler {
20+
return try await self.value
21+
} onCancel: {
22+
self.cancel()
23+
}
24+
}
25+
}
26+
}

Sources/SKSupport/ThreadSafeBox.swift

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
import Foundation
14+
15+
extension NSLock {
16+
/// NOTE: Keep in sync with SwiftPM's 'Sources/Basics/NSLock+Extensions.swift'
17+
fileprivate func withLock<T>(_ body: () throws -> T) rethrows -> T {
18+
lock()
19+
defer { unlock() }
20+
return try body()
21+
}
22+
}
23+
24+
/// A thread safe container that contains a value of type `T`.
25+
public class ThreadSafeBox<T> {
26+
/// Lock guarding `_value`.
27+
private let lock = NSLock()
28+
29+
private var _value: T
30+
31+
public var value: T {
32+
get {
33+
return lock.withLock {
34+
return _value
35+
}
36+
}
37+
set {
38+
lock.withLock {
39+
_value = newValue
40+
}
41+
}
42+
}
43+
44+
public init(initialValue: T) {
45+
_value = initialValue
46+
}
47+
}

Sources/SourceKitD/SourceKitD.swift

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -83,26 +83,30 @@ extension SourceKitD {
8383
public func send(_ req: SKDRequestDictionary) async throws -> SKDResponseDictionary {
8484
logRequest(req)
8585

86-
let sourcekitdResponse: SKDResponse = await withCheckedContinuation { continuation in
87-
var handle: sourcekitd_request_handle_t? = nil
88-
89-
api.send_request(req.dict, &handle) { _resp in
90-
continuation.resume(returning: SKDResponse(_resp, sourcekitd: self))
86+
let handleWrapper = ThreadSafeBox<sourcekitd_request_handle_t?>(initialValue: nil)
87+
let sourcekitdResponse: SKDResponse = try await withTaskCancellationHandler {
88+
try Task.checkCancellation()
89+
return await withCheckedContinuation { continuation in
90+
var handle: sourcekitd_request_handle_t? = nil
91+
92+
api.send_request(req.dict, &handle) { _resp in
93+
continuation.resume(returning: SKDResponse(_resp, sourcekitd: self))
94+
}
95+
handleWrapper.value = handle
96+
}
97+
} onCancel: {
98+
if let handle = handleWrapper.value {
99+
api.cancel_request(handle)
91100
}
92101
}
102+
93103
logResponse(sourcekitdResponse)
94104

95105
guard let dict = sourcekitdResponse.value else {
96106
throw sourcekitdResponse.error!
97107
}
98108

99109
return dict
100-
101-
// FIXME: (async) Cancellation
102-
}
103-
104-
public func cancel(_ handle: sourcekitd_request_handle_t) {
105-
api.cancel_request(handle)
106110
}
107111
}
108112

Sources/SourceKitLSP/Clang/ClangLanguageServer.swift

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -327,17 +327,7 @@ actor ClangLanguageServerShim: ToolchainLanguageServer, MessageHandler {
327327
///
328328
/// The response of the request is returned asynchronously as the return value.
329329
func forwardRequestToClangd<R: RequestType>(_ request: R) async throws -> R.Response {
330-
try await withCheckedThrowingContinuation { continuation in
331-
_ = clangd.send(request) { result in
332-
switch result {
333-
case .success(let response):
334-
continuation.resume(returning: response)
335-
case .failure(let error):
336-
continuation.resume(throwing: error)
337-
}
338-
}
339-
}
340-
// FIXME: (async) Cancellation
330+
return try await clangd.send(request)
341331
}
342332

343333
func _crash() {

Sources/SourceKitLSP/SourceKitServer.swift

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,12 @@ public actor SourceKitServer {
262262
/// The actual semantic handling of the message happens off this queue.
263263
private let messageHandlingQueue = AsyncQueue<TaskMetadata>()
264264

265+
/// The queue on which we start and stop keeping track of cancellation.
266+
///
267+
/// Having a queue for this ensures that we started keeping track of a
268+
/// request's task before handling any cancellation request for it.
269+
private let cancellationMessageHandlingQueue = AsyncQueue<Serial>()
270+
265271
/// The connection to the editor.
266272
public let client: Connection
267273

@@ -305,6 +311,17 @@ public actor SourceKitServer {
305311
}
306312
}
307313

314+
/// The requests that we are currently handling.
315+
///
316+
/// Used to cancel the tasks if the client requests cancellation.
317+
private var inProgressRequests: [RequestID: Task<(), Never>] = [:]
318+
319+
/// - Note: Needed so we can set an in-progress request from a different
320+
/// isolation context.
321+
private func setInProgressRequest(for id: RequestID, task: Task<(), Never>?) {
322+
self.inProgressRequests[id] = task
323+
}
324+
308325
let fs: FileSystem
309326

310327
var onExit: () -> Void
@@ -400,12 +417,7 @@ public actor SourceKitServer {
400417

401418
/// Send the given request to the editor.
402419
public func sendRequestToClient<R: RequestType>(_ request: R) async throws -> R.Response {
403-
try await withCheckedThrowingContinuation { continuation in
404-
_ = client.send(request) { result in
405-
continuation.resume(with: result)
406-
}
407-
// FIXME: (async) Handle cancellation
408-
}
420+
return try await client.send(request)
409421
}
410422

411423
func toolchain(for uri: DocumentURI, _ language: Language) -> Toolchain? {
@@ -608,6 +620,13 @@ private func getNextNotificationIDForLogging() -> Int {
608620

609621
extension SourceKitServer: MessageHandler {
610622
public nonisolated func handle(_ params: some NotificationType, from clientID: ObjectIdentifier) {
623+
if let params = params as? CancelRequestNotification {
624+
// Request cancellation needs to be able to overtake any other message we
625+
// are currently handling. Ordering is not important here. We thus don't
626+
// need to execute it on `messageHandlingQueue`.
627+
self.cancelRequest(params)
628+
}
629+
611630
messageHandlingQueue.async(metadata: TaskMetadata(params)) {
612631
let notificationID = getNextNotificationIDForLogging()
613632

@@ -630,8 +649,6 @@ extension SourceKitServer: MessageHandler {
630649
switch notification.params {
631650
case let notification as InitializedNotification:
632651
self.clientInitialized(notification)
633-
case let notification as CancelRequestNotification:
634-
self.cancelRequest(notification)
635652
case let notification as ExitNotification:
636653
await self.exit(notification)
637654
case let notification as DidOpenTextDocumentNotification:
@@ -660,10 +677,21 @@ extension SourceKitServer: MessageHandler {
660677
from clientID: ObjectIdentifier,
661678
reply: @escaping (LSPResult<R.Response>) -> Void
662679
) {
663-
messageHandlingQueue.async(metadata: TaskMetadata(params)) {
680+
let task = messageHandlingQueue.async(metadata: TaskMetadata(params)) {
664681
await withLoggingScope("request-\(id)") {
665682
await self.handleImpl(params, id: id, from: clientID, reply: reply)
666683
}
684+
// We have handled the request and can't cancel it anymore.
685+
// Stop keeping track of it to free the memory.
686+
self.cancellationMessageHandlingQueue.async(priority: .background) {
687+
await self.setInProgressRequest(for: id, task: nil)
688+
}
689+
}
690+
// Keep track of the ID -> Task management with low priority. Once we cancel
691+
// a request, the cancellation task runs with a high priority and depends on
692+
// this task, which will elevate this task's priority.
693+
cancellationMessageHandlingQueue.async(priority: .background) {
694+
await self.setInProgressRequest(for: id, task: task)
667695
}
668696
}
669697

@@ -1077,8 +1105,18 @@ extension SourceKitServer {
10771105
// Nothing to do.
10781106
}
10791107

1080-
func cancelRequest(_ notification: CancelRequestNotification) {
1081-
// TODO: Implement cancellation
1108+
nonisolated func cancelRequest(_ notification: CancelRequestNotification) {
1109+
// Since the request is very cheap to execute and stops other requests
1110+
// from performing more work, we execute it with a high priority.
1111+
cancellationMessageHandlingQueue.async(priority: .high) {
1112+
guard let task = await self.inProgressRequests[notification.id] else {
1113+
logger.error(
1114+
"Cannot cancel request \(notification.id, privacy: .public) because it hasn't been scheduled for execution yet"
1115+
)
1116+
return
1117+
}
1118+
task.cancel()
1119+
}
10821120
}
10831121

10841122
/// The server is about to exit, and the server should flush any buffered state.

Sources/SourceKitLSP/Swift/CodeCompletionSession.swift

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,7 @@ class CodeCompletionSession {
138138
return try await session.open(filterText: filterText, position: cursorPosition, in: snapshot, options: options)
139139
}
140140

141-
// FIXME: (async) Use valuePropagatingCancellation once we support cancellation
142-
return try await task.value
141+
return try await task.valuePropagatingCancellation
143142
}
144143

145144
// MARK: - Implementation

Sources/SourceKitLSP/Swift/SourceKitD+ResponseError.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ extension ResponseError {
1717
public init(_ value: SKDError) {
1818
switch value {
1919
case .requestCancelled:
20-
self = .serverCancelled
20+
self = .cancelled
2121
case .requestFailed(let desc):
2222
self = .unknown("sourcekitd request failed: \(desc)")
2323
case .requestInvalid(let desc):

0 commit comments

Comments
 (0)