@@ -15,6 +15,7 @@ import class Foundation.JSONDecoder
15
15
import class Foundation. NSError
16
16
import class Foundation. OperationQueue
17
17
import struct Foundation. URL
18
+ import struct Foundation. UUID
18
19
import TSCBasic
19
20
20
21
#if canImport(Glibc)
@@ -35,7 +36,7 @@ public enum HTTPClientError: Error, Equatable {
35
36
36
37
// MARK: - HTTPClient
37
38
38
- public struct HTTPClient {
39
+ public struct HTTPClient : Cancellable {
39
40
public typealias Configuration = HTTPClientConfiguration
40
41
public typealias Request = HTTPClientRequest
41
42
public typealias Response = HTTPClientResponse
@@ -47,9 +48,12 @@ public struct HTTPClient {
47
48
private let underlying : Handler
48
49
49
50
/// DispatchSemaphore to restrict concurrent operations on manager.
50
- private let concurrencySemaphore : DispatchSemaphore
51
- /// OperationQueue to park pending requests
52
- private let requestsQueue : OperationQueue
51
+ private let concurrencySemaphore : DispatchSemaphore
52
+ /// OperationQueue to park pending requests
53
+ private let requestsQueue : OperationQueue
54
+
55
+ // tracks outstanding requests for cancellation
56
+ private var outstandingRequests = ThreadSafeKeyValueStore < UUID , ( completion: CompletionHandler , progress: ProgressHandler ? , queue: DispatchQueue ) > ( )
53
57
54
58
// static to share across instances of the http client
55
59
private static var hostsErrorsLock = Lock ( )
@@ -76,7 +80,12 @@ public struct HTTPClient {
76
80
/// - observabilityScope: the observability scope to emit diagnostics on
77
81
/// - progress: A progress handler to handle progress for example for downloads
78
82
/// - completion: A completion handler to be notified of the completion of the request.
79
- public func execute( _ request: Request , observabilityScope: ObservabilityScope ? = nil , progress: ProgressHandler ? = nil , completion: @escaping CompletionHandler ) {
83
+ public func execute(
84
+ _ request: Request ,
85
+ observabilityScope: ObservabilityScope ? = nil ,
86
+ progress: ProgressHandler ? = nil ,
87
+ completion: @escaping CompletionHandler
88
+ ) {
80
89
// merge configuration
81
90
var request = request
82
91
if request. options. callbackQueue == nil {
@@ -107,7 +116,9 @@ public struct HTTPClient {
107
116
request. headers. add ( name: " Authorization " , value: authorization)
108
117
}
109
118
// execute
110
- let callbackQueue = request. options. callbackQueue ?? self . configuration. callbackQueue
119
+ guard let callbackQueue = request. options. callbackQueue else {
120
+ return completion ( . failure( InternalError ( " unknown callback queue " ) ) )
121
+ }
111
122
self . _execute (
112
123
request: request,
113
124
requestNumber: 0 ,
@@ -129,13 +140,44 @@ public struct HTTPClient {
129
140
)
130
141
}
131
142
132
- private func _execute( request: Request , requestNumber: Int , observabilityScope: ObservabilityScope ? , progress: ProgressHandler ? , completion: @escaping CompletionHandler ) {
143
+ /// Cancel any outstanding requests
144
+ public func cancel( deadline: DispatchTime ) {
145
+ let outstanding = self . outstandingRequests. get ( )
146
+ self . outstandingRequests. clear ( )
147
+ for (completion, _, queue) in outstanding. values {
148
+ queue. async {
149
+ completion ( . failure( CancellationError ( ) ) )
150
+ }
151
+ }
152
+ }
153
+
154
+ private func _execute(
155
+ request: Request ,
156
+ requestNumber: Int ,
157
+ observabilityScope: ObservabilityScope ? ,
158
+ progress: ProgressHandler ? ,
159
+ completion: @escaping CompletionHandler
160
+ ) {
161
+ // records outstanding requests for cancellation purposes
162
+ guard let callbackQueue = request. options. callbackQueue else {
163
+ return completion ( . failure( InternalError ( " unknown callback queue " ) ) )
164
+ }
165
+ let requestKey = UUID ( )
166
+ self . outstandingRequests [ requestKey] = ( completion: completion, progress: progress, queue: callbackQueue)
167
+
133
168
// wrap completion handler with concurrency control cleanup
134
169
let originalCompletion = completion
135
170
let completion : CompletionHandler = { result in
136
171
// free concurrency control semaphore
137
172
self . concurrencySemaphore. signal ( )
138
- originalCompletion ( result)
173
+ // cancellation support
174
+ // if the callback is no longer on the pending lists it has been canceled already
175
+ if let ( callback, _, queue) = self . outstandingRequests [ requestKey] {
176
+ // remove from outstanding requests
177
+ self . outstandingRequests [ requestKey] = nil
178
+ // call back on the request queue
179
+ queue. async { callback ( result) }
180
+ }
139
181
}
140
182
141
183
// we must not block the calling thread (for concurrency control) so nesting this in a queue
@@ -172,9 +214,11 @@ public struct HTTPClient {
172
214
// handle retry strategy
173
215
if let retryDelay = self . shouldRetry ( response: response, request: request, requestNumber: requestNumber) {
174
216
observabilityScope? . emit ( warning: " \( request. url) failed, retrying in \( retryDelay) " )
175
- // free concurrency control semaphore, since we re-submitting the request with the original completion handler
176
- // using the wrapped completion handler may lead to starving the mac concurrent requests
217
+ // free concurrency control semaphore and outstanding request,
218
+ // since we re-submitting the request with the original completion handler
219
+ // using the wrapped completion handler may lead to starving the max concurrent requests
177
220
self . concurrencySemaphore. signal ( )
221
+ self . outstandingRequests [ requestKey] = nil
178
222
// TODO: dedicated retry queue?
179
223
return self . configuration. callbackQueue. asyncAfter ( deadline: . now( ) + retryDelay) {
180
224
self . _execute ( request: request, requestNumber: requestNumber + 1 , observabilityScope: observabilityScope, progress: progress, completion: originalCompletion)
0 commit comments