Skip to content

Commit e0cf003

Browse files
authored
[URLSession] Do not crash for unsupported URLs (#3154)
* [URLSession] Do not crash for unsupported URLs * [Tests] Add new URLSession tests to allTests array
1 parent 62cc39c commit e0cf003

File tree

2 files changed

+59
-5
lines changed

2 files changed

+59
-5
lines changed

Sources/FoundationNetworking/URLSession/URLSessionTask.swift

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ open class URLSessionTask : NSObject, NSCopying {
115115
fileprivate var _protocolStorage: ProtocolState = .toBeCreated
116116
internal var _lastCredentialUsedFromStorageDuringAuthentication: (protectionSpace: URLProtectionSpace, credential: URLCredential)?
117117

118-
private var _protocolClass: URLProtocol.Type {
118+
private var _protocolClass: URLProtocol.Type? {
119119
guard let request = currentRequest else { fatalError("A protocol class was requested, but we do not have a current request") }
120120
let protocolClasses = session.configuration.protocolClasses ?? []
121121
if let urlProtocolClass = URLProtocol.getProtocolClass(protocols: protocolClasses, request: request) {
@@ -128,15 +128,19 @@ open class URLSessionTask : NSObject, NSCopying {
128128
return urlProtocol
129129
}
130130
}
131-
132-
fatalError("Couldn't find a protocol appropriate for request: \(request)")
131+
return nil
133132
}
134133

135134
func _getProtocol(_ callback: @escaping (URLProtocol?) -> Void) {
136135
_protocolLock.lock() // Must be balanced below, before we call out ⬇
137136

138137
switch _protocolStorage {
139138
case .toBeCreated:
139+
guard let protocolClass = self._protocolClass else {
140+
_protocolLock.unlock() // Balances above ⬆
141+
callback(nil)
142+
break
143+
}
140144
if let cache = session.configuration.urlCache, let me = self as? URLSessionDataTask {
141145
let bag: Bag<(URLProtocol?) -> Void> = Bag()
142146
bag.values.append(callback)
@@ -145,11 +149,11 @@ open class URLSessionTask : NSObject, NSCopying {
145149
_protocolLock.unlock() // Balances above ⬆
146150

147151
cache.getCachedResponse(for: me) { (response) in
148-
let urlProtocol = self._protocolClass.init(task: self, cachedResponse: response, client: nil)
152+
let urlProtocol = protocolClass.init(task: self, cachedResponse: response, client: nil)
149153
self._satisfyProtocolRequest(with: urlProtocol)
150154
}
151155
} else {
152-
let urlProtocol = _protocolClass.init(task: self, cachedResponse: nil, client: nil)
156+
let urlProtocol = protocolClass.init(task: self, cachedResponse: nil, client: nil)
153157
_protocolStorage = .existing(urlProtocol)
154158
_protocolLock.unlock() // Balances above ⬆
155159

Tests/Foundation/Tests/TestURLSession.swift

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,54 @@ class TestURLSession: LoopbackServerTest {
331331
waitForExpectations(timeout: 12)
332332
}
333333

334+
func test_unhandledURLProtocol() {
335+
let urlString = "foobar://127.0.0.1:\(TestURLSession.serverPort)/Nepal"
336+
let url = URL(string: urlString)!
337+
let session = URLSession(configuration: URLSessionConfiguration.default,
338+
delegate: nil,
339+
delegateQueue: nil)
340+
let completionExpectation = expectation(description: "GET \(urlString): Unsupported URL error")
341+
let task = session.dataTask(with: url) { (data, response, _error) in
342+
XCTAssertNil(data)
343+
XCTAssertNil(response)
344+
let error = _error as? URLError
345+
XCTAssertNotNil(error)
346+
XCTAssertEqual(error?.code, .unsupportedURL)
347+
completionExpectation.fulfill()
348+
}
349+
task.resume()
350+
351+
waitForExpectations(timeout: 5) { error in
352+
XCTAssertNil(error)
353+
XCTAssertEqual((task.error as? URLError)?.code, .unsupportedURL)
354+
}
355+
}
356+
357+
func test_requestToNilURL() {
358+
let urlString = "http://127.0.0.1:\(TestURLSession.serverPort)/Nepal"
359+
let url = URL(string: urlString)!
360+
let session = URLSession(configuration: URLSessionConfiguration.default,
361+
delegate: nil,
362+
delegateQueue: nil)
363+
let completionExpectation = expectation(description: "DataTask with nil URL: Unsupported URL error")
364+
var request = URLRequest(url: url)
365+
request.url = nil
366+
let task = session.dataTask(with: request) { (data, response, _error) in
367+
XCTAssertNil(data)
368+
XCTAssertNil(response)
369+
let error = _error as? URLError
370+
XCTAssertNotNil(error)
371+
XCTAssertEqual(error?.code, .unsupportedURL)
372+
completionExpectation.fulfill()
373+
}
374+
task.resume()
375+
376+
waitForExpectations(timeout: 5) { error in
377+
XCTAssertNil(error)
378+
XCTAssertEqual((task.error as? URLError)?.code, .unsupportedURL)
379+
}
380+
}
381+
334382
func test_suspendResumeTask() throws {
335383
let urlString = "http://127.0.0.1:\(TestURLSession.serverPort)/get"
336384
let url = try XCTUnwrap(URL(string: urlString))
@@ -1819,6 +1867,8 @@ class TestURLSession: LoopbackServerTest {
18191867
("test_taskError", test_taskError),
18201868
("test_taskCopy", test_taskCopy),
18211869
("test_cancelTask", test_cancelTask),
1870+
("test_unhandledURLProtocol", test_unhandledURLProtocol),
1871+
("test_requestToNilURL", test_requestToNilURL),
18221872
/* ⚠️ */ ("test_suspendResumeTask", testExpectedToFail(test_suspendResumeTask, "Occasionally breaks")),
18231873
("test_taskTimeout", test_taskTimeout),
18241874
("test_verifyRequestHeaders", test_verifyRequestHeaders),

0 commit comments

Comments
 (0)