Skip to content

Commit a023bd3

Browse files
authored
add progress tracking to http client (#3255)
motivation: * safer handling of large responses * prepare to consolidate Downloader from TSC and http client from Basics changes: * add progress support to HTTPclient::execute and URLSession based impl * add HTTPClientRequest::Options::maxResponseSize and logic to fail the request if response is too large * add authorization support to HTTPClientRequest::Options and HTTPClient::Configuration * refactor call-sites to use new ways * add and adopt tests
1 parent f2146ac commit a023bd3

File tree

9 files changed

+359
-177
lines changed

9 files changed

+359
-177
lines changed

Sources/Basics/ConcurrencyHelpers.swift

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
See http://swift.org/CONTRIBUTORS.txt for Swift project authors
77
*/
88

9-
import TSCBasic
109
import class Foundation.ProcessInfo
10+
import TSCBasic
1111

1212
/// Thread-safe dictionary like structure
1313
public final class ThreadSafeKeyValueStore<Key, Value> where Key: Hashable {
@@ -37,6 +37,12 @@ public final class ThreadSafeKeyValueStore<Key, Value> where Key: Hashable {
3737
}
3838
}
3939

40+
public func removeValue(forKey key: Key) -> Value? {
41+
self.lock.withLock {
42+
self.underlying.removeValue(forKey: key)
43+
}
44+
}
45+
4046
public func clear() {
4147
self.lock.withLock {
4248
self.underlying.removeAll()

Sources/Basics/HTPClient+URLSession.swift

Lines changed: 76 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1+
/*
2+
This source file is part of the Swift.org open source project
13

4+
Copyright (c) 2020 Apple Inc. and the Swift project authors
5+
Licensed under Apache License v2.0 with Runtime Library Exception
6+
7+
See http://swift.org/LICENSE.txt for license information
8+
See http://swift.org/CONTRIBUTORS.txt for Swift project authors
9+
*/
210

311
import Foundation
412
import struct TSCUtility.Versioning
@@ -8,28 +16,84 @@ import struct TSCUtility.Versioning
816
import FoundationNetworking
917
#endif
1018

11-
public struct URLSessionHTTPClient: HTTPClientProtocol {
19+
public final class URLSessionHTTPClient: NSObject, HTTPClientProtocol {
1220
private let configuration: URLSessionConfiguration
21+
private let delegateQueue: OperationQueue
22+
private var session: URLSession!
23+
private var tasks = ThreadSafeKeyValueStore<Int, DataTask>()
1324

1425
public init(configuration: URLSessionConfiguration = .default) {
1526
self.configuration = configuration
27+
self.delegateQueue = OperationQueue()
28+
self.delegateQueue.name = "org.swift.swiftpm.urlsession-http-client"
29+
self.delegateQueue.maxConcurrentOperationCount = 1
30+
super.init()
31+
self.session = URLSession(configuration: self.configuration, delegate: self, delegateQueue: self.delegateQueue)
1632
}
1733

18-
public func execute(_ request: HTTPClient.Request, callback: @escaping (Result<HTTPClient.Response, Error>) -> Void) {
19-
let session = URLSession(configuration: self.configuration)
20-
let task = session.dataTask(with: request.urlRequest()) { data, response, error in
21-
if let error = error {
22-
callback(.failure(error))
23-
} else if let response = response as? HTTPURLResponse {
24-
callback(.success(response.response(body: data)))
25-
} else {
26-
callback(.failure(HTTPClientError.invalidResponse))
27-
}
28-
}
34+
public func execute(_ request: HTTPClient.Request, progress: ProgressHandler?, completion: @escaping CompletionHandler) {
35+
let task = self.session.dataTask(with: request.urlRequest())
36+
self.tasks[task.taskIdentifier] = DataTask(task: task, progressHandler: progress, completionHandler: completion)
2937
task.resume()
3038
}
3139
}
3240

41+
extension URLSessionHTTPClient: URLSessionDataDelegate {
42+
public func urlSession(_ session: URLSession,
43+
dataTask: URLSessionDataTask,
44+
didReceive response: URLResponse,
45+
completionHandler: @escaping (URLSession.ResponseDisposition) -> Void) {
46+
guard let task = self.tasks[dataTask.taskIdentifier] else {
47+
return completionHandler(.cancel)
48+
}
49+
task.response = response as? HTTPURLResponse
50+
task.expectedContentLength = response.expectedContentLength
51+
task.progressHandler?(0, response.expectedContentLength)
52+
completionHandler(.allow)
53+
}
54+
55+
public func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) {
56+
guard let task = self.tasks[dataTask.taskIdentifier] else {
57+
return
58+
}
59+
if task.buffer != nil {
60+
task.buffer?.append(data)
61+
} else {
62+
task.buffer = data
63+
}
64+
task.progressHandler?(Int64(task.buffer?.count ?? 0), task.expectedContentLength) // safe since created in the line above
65+
}
66+
67+
public func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
68+
guard let task = self.tasks.removeValue(forKey: task.taskIdentifier) else {
69+
return
70+
}
71+
if let error = error {
72+
task.completionHandler(.failure(error))
73+
} else if let response = task.response {
74+
task.completionHandler(.success(response.response(body: task.buffer)))
75+
} else {
76+
task.completionHandler(.failure(HTTPClientError.invalidResponse))
77+
}
78+
}
79+
80+
class DataTask {
81+
let task: URLSessionDataTask
82+
let completionHandler: CompletionHandler
83+
let progressHandler: ProgressHandler?
84+
85+
var response: HTTPURLResponse?
86+
var expectedContentLength: Int64?
87+
var buffer: Data?
88+
89+
init(task: URLSessionDataTask, progressHandler: ProgressHandler?, completionHandler: @escaping CompletionHandler) {
90+
self.task = task
91+
self.progressHandler = progressHandler
92+
self.completionHandler = completionHandler
93+
}
94+
}
95+
}
96+
3397
extension HTTPClient.Request {
3498
func urlRequest() -> URLRequest {
3599
var request = URLRequest(url: self.url)

Sources/Basics/HTTPClient.swift

Lines changed: 85 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,24 @@ import CRT
2525
#endif
2626

2727
public protocol HTTPClientProtocol {
28-
func execute(_ request: HTTPClientRequest, callback: @escaping (Result<HTTPClientResponse, Error>) -> Void)
28+
typealias ProgressHandler = (_ bytesReceived: Int64, _ totalBytes: Int64?) -> Void
29+
typealias CompletionHandler = (Result<HTTPClientResponse, Error>) -> Void
30+
31+
/// Execute an HTTP request asynchronously
32+
///
33+
/// - Parameters:
34+
/// - request: The `HTTPClientRequest` to perform.
35+
/// - callback: A closure to be notified of the completion of the request.
36+
func execute(_ request: HTTPClientRequest,
37+
progress: ProgressHandler?,
38+
completion: @escaping CompletionHandler)
2939
}
3040

3141
public enum HTTPClientError: Error, Equatable {
3242
case invalidResponse
3343
case badResponseStatusCode(Int)
3444
case circuitBreakerTriggered
45+
case responseTooLarge(Int64)
3546
}
3647

3748
// MARK: - HTTPClient
@@ -40,7 +51,7 @@ public struct HTTPClient: HTTPClientProtocol {
4051
public typealias Configuration = HTTPClientConfiguration
4152
public typealias Request = HTTPClientRequest
4253
public typealias Response = HTTPClientResponse
43-
public typealias Handler = (Request, @escaping (Result<Response, Error>) -> Void) -> Void
54+
public typealias Handler = (Request, ProgressHandler?, @escaping (Result<Response, Error>) -> Void) -> Void
4455

4556
public var configuration: HTTPClientConfiguration
4657
private let diagnosticsEngine: DiagnosticsEngine?
@@ -57,7 +68,7 @@ public struct HTTPClient: HTTPClientProtocol {
5768
self.underlying = handler ?? URLSessionHTTPClient().execute
5869
}
5970

60-
public func execute(_ request: Request, callback: @escaping (Result<Response, Error>) -> Void) {
71+
public func execute(_ request: Request, progress: ProgressHandler? = nil, completion: @escaping CompletionHandler) {
6172
// merge configuration
6273
var request = request
6374
if request.options.callbackQueue == nil {
@@ -72,6 +83,9 @@ public struct HTTPClient: HTTPClientProtocol {
7283
if request.options.timeout == nil {
7384
request.options.timeout = self.configuration.requestTimeout
7485
}
86+
if request.options.authorizationProvider == nil {
87+
request.options.authorizationProvider = self.configuration.authorizationProvider
88+
}
7589
// add additional headers
7690
if let additionalHeaders = self.configuration.requestHeaders {
7791
additionalHeaders.forEach {
@@ -81,43 +95,68 @@ public struct HTTPClient: HTTPClientProtocol {
8195
if request.options.addUserAgent, !request.headers.contains("User-Agent") {
8296
request.headers.add(name: "User-Agent", value: "SwiftPackageManager/\(SwiftVersion.currentVersion.displayString)")
8397
}
98+
if let authorization = request.options.authorizationProvider?(request.url), !request.headers.contains("Authorization") {
99+
request.headers.add(name: "Authorization", value: authorization)
100+
}
84101
// execute
85-
self._execute(request: request, requestNumber: 0) { result in
86-
let callbackQueue = request.options.callbackQueue ?? self.configuration.callbackQueue
87-
callbackQueue.async {
88-
callback(result)
102+
let callbackQueue = request.options.callbackQueue ?? self.configuration.callbackQueue
103+
self._execute(
104+
request: request, requestNumber: 0,
105+
progress: progress.map { handler in
106+
{ received, expected in
107+
callbackQueue.async {
108+
handler(received, expected)
109+
}
110+
}
111+
},
112+
completion: { result in
113+
callbackQueue.async {
114+
completion(result)
115+
}
89116
}
90-
}
117+
)
91118
}
92119

93-
private func _execute(request: Request, requestNumber: Int, callback: @escaping (Result<Response, Error>) -> Void) {
120+
private func _execute(request: Request, requestNumber: Int, progress: ProgressHandler?, completion: @escaping CompletionHandler) {
94121
if self.shouldCircuitBreak(request: request) {
95122
diagnosticsEngine?.emit(warning: "Circuit breaker triggered for \(request.url)")
96-
return callback(.failure(HTTPClientError.circuitBreakerTriggered))
123+
return completion(.failure(HTTPClientError.circuitBreakerTriggered))
97124
}
98125

99-
self.underlying(request) { result in
100-
switch result {
101-
case .failure(let error):
102-
callback(.failure(error))
103-
case .success(let response):
104-
// record host errors for circuit breaker
105-
self.recordErrorIfNecessary(response: response, request: request)
106-
// handle retry strategy
107-
if let retryDelay = self.shouldRetry(response: response, request: request, requestNumber: requestNumber) {
108-
self.diagnosticsEngine?.emit(warning: "\(request.url) failed, retrying in \(retryDelay)")
109-
// TODO: dedicated retry queue?
110-
return self.configuration.callbackQueue.asyncAfter(deadline: .now() + retryDelay) {
111-
self._execute(request: request, requestNumber: requestNumber + 1, callback: callback)
126+
self.underlying(
127+
request,
128+
{ received, expected in
129+
if let max = request.options.maximumResponseSizeInBytes {
130+
guard received < max else {
131+
// FIXME: cancel the request?
132+
return completion(.failure(HTTPClientError.responseTooLarge(received)))
112133
}
113134
}
114-
// check for valid response codes
115-
if let validResponseCodes = request.options.validResponseCodes, !validResponseCodes.contains(response.statusCode) {
116-
return callback(.failure(HTTPClientError.badResponseStatusCode(response.statusCode)))
135+
progress?(received, expected)
136+
},
137+
{ result in
138+
switch result {
139+
case .failure(let error):
140+
completion(.failure(error))
141+
case .success(let response):
142+
// record host errors for circuit breaker
143+
self.recordErrorIfNecessary(response: response, request: request)
144+
// handle retry strategy
145+
if let retryDelay = self.shouldRetry(response: response, request: request, requestNumber: requestNumber) {
146+
self.diagnosticsEngine?.emit(warning: "\(request.url) failed, retrying in \(retryDelay)")
147+
// TODO: dedicated retry queue?
148+
return self.configuration.callbackQueue.asyncAfter(deadline: .now() + retryDelay) {
149+
self._execute(request: request, requestNumber: requestNumber + 1, progress: progress, completion: completion)
150+
}
151+
}
152+
// check for valid response codes
153+
if let validResponseCodes = request.options.validResponseCodes, !validResponseCodes.contains(response.statusCode) {
154+
return completion(.failure(HTTPClientError.badResponseStatusCode(response.statusCode)))
155+
}
156+
completion(.success(response))
117157
}
118-
callback(.success(response))
119158
}
120-
}
159+
)
121160
}
122161

123162
private func shouldRetry(response: Response, request: Request, requestNumber: Int) -> DispatchTimeInterval? {
@@ -179,39 +218,43 @@ public struct HTTPClient: HTTPClientProtocol {
179218
}
180219

181220
public extension HTTPClient {
182-
func head(_ url: URL, headers: HTTPClientHeaders = .init(), options: Request.Options = .init(), callback: @escaping (Result<Response, Error>) -> Void) {
183-
self.execute(Request(method: .head, url: url, headers: headers, body: nil, options: options), callback: callback)
221+
func head(_ url: URL, headers: HTTPClientHeaders = .init(), options: Request.Options = .init(), completion: @escaping (Result<Response, Error>) -> Void) {
222+
self.execute(Request(method: .head, url: url, headers: headers, body: nil, options: options), completion: completion)
184223
}
185224

186-
func get(_ url: URL, headers: HTTPClientHeaders = .init(), options: Request.Options = .init(), callback: @escaping (Result<Response, Error>) -> Void) {
187-
self.execute(Request(method: .get, url: url, headers: headers, body: nil, options: options), callback: callback)
225+
func get(_ url: URL, headers: HTTPClientHeaders = .init(), options: Request.Options = .init(), completion: @escaping (Result<Response, Error>) -> Void) {
226+
self.execute(Request(method: .get, url: url, headers: headers, body: nil, options: options), completion: completion)
188227
}
189228

190-
func put(_ url: URL, body: Data?, headers: HTTPClientHeaders = .init(), options: Request.Options = .init(), callback: @escaping (Result<Response, Error>) -> Void) {
191-
self.execute(Request(method: .put, url: url, headers: headers, body: body, options: options), callback: callback)
229+
func put(_ url: URL, body: Data?, headers: HTTPClientHeaders = .init(), options: Request.Options = .init(), completion: @escaping (Result<Response, Error>) -> Void) {
230+
self.execute(Request(method: .put, url: url, headers: headers, body: body, options: options), completion: completion)
192231
}
193232

194-
func post(_ url: URL, body: Data?, headers: HTTPClientHeaders = .init(), options: Request.Options = .init(), callback: @escaping (Result<Response, Error>) -> Void) {
195-
self.execute(Request(method: .post, url: url, headers: headers, body: body, options: options), callback: callback)
233+
func post(_ url: URL, body: Data?, headers: HTTPClientHeaders = .init(), options: Request.Options = .init(), completion: @escaping (Result<Response, Error>) -> Void) {
234+
self.execute(Request(method: .post, url: url, headers: headers, body: body, options: options), completion: completion)
196235
}
197236

198-
func delete(_ url: URL, headers: HTTPClientHeaders = .init(), options: Request.Options = .init(), callback: @escaping (Result<Response, Error>) -> Void) {
199-
self.execute(Request(method: .delete, url: url, headers: headers, body: nil, options: options), callback: callback)
237+
func delete(_ url: URL, headers: HTTPClientHeaders = .init(), options: Request.Options = .init(), completion: @escaping (Result<Response, Error>) -> Void) {
238+
self.execute(Request(method: .delete, url: url, headers: headers, body: nil, options: options), completion: completion)
200239
}
201240
}
202241

203242
// MARK: - HTTPClientConfiguration
204243

244+
public typealias HTTPClientAuthorizationProvider = (URL) -> String?
245+
205246
public struct HTTPClientConfiguration {
206247
public var requestHeaders: HTTPClientHeaders?
207248
public var requestTimeout: DispatchTimeInterval?
249+
public var authorizationProvider: HTTPClientAuthorizationProvider?
208250
public var retryStrategy: HTTPClientRetryStrategy?
209251
public var circuitBreakerStrategy: HTTPClientCircuitBreakerStrategy?
210252
public var callbackQueue: DispatchQueue
211253

212254
public init() {
213255
self.requestHeaders = .none
214256
self.requestTimeout = .none
257+
self.authorizationProvider = .none
215258
self.retryStrategy = .none
216259
self.circuitBreakerStrategy = .none
217260
self.callbackQueue = .global()
@@ -259,6 +302,8 @@ public struct HTTPClientRequest {
259302
public var addUserAgent: Bool
260303
public var validResponseCodes: [Int]?
261304
public var timeout: DispatchTimeInterval?
305+
public var maximumResponseSizeInBytes: Int64?
306+
public var authorizationProvider: HTTPClientAuthorizationProvider?
262307
public var retryStrategy: HTTPClientRetryStrategy?
263308
public var circuitBreakerStrategy: HTTPClientCircuitBreakerStrategy?
264309
public var callbackQueue: DispatchQueue?
@@ -267,6 +312,8 @@ public struct HTTPClientRequest {
267312
self.addUserAgent = true
268313
self.validResponseCodes = .none
269314
self.timeout = .none
315+
self.maximumResponseSizeInBytes = .none
316+
self.authorizationProvider = .none
270317
self.retryStrategy = .none
271318
self.circuitBreakerStrategy = .none
272319
self.callbackQueue = .none

0 commit comments

Comments
 (0)