Skip to content

[URLSession] Do not crash for unsupported URLs #3154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions Sources/FoundationNetworking/URLSession/URLSessionTask.swift
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ open class URLSessionTask : NSObject, NSCopying {
fileprivate var _protocolStorage: ProtocolState = .toBeCreated
internal var _lastCredentialUsedFromStorageDuringAuthentication: (protectionSpace: URLProtectionSpace, credential: URLCredential)?

private var _protocolClass: URLProtocol.Type {
private var _protocolClass: URLProtocol.Type? {
guard let request = currentRequest else { fatalError("A protocol class was requested, but we do not have a current request") }
let protocolClasses = session.configuration.protocolClasses ?? []
if let urlProtocolClass = URLProtocol.getProtocolClass(protocols: protocolClasses, request: request) {
Expand All @@ -128,15 +128,19 @@ open class URLSessionTask : NSObject, NSCopying {
return urlProtocol
}
}

fatalError("Couldn't find a protocol appropriate for request: \(request)")
return nil
}

func _getProtocol(_ callback: @escaping (URLProtocol?) -> Void) {
_protocolLock.lock() // Must be balanced below, before we call out ⬇

switch _protocolStorage {
case .toBeCreated:
guard let protocolClass = self._protocolClass else {
_protocolLock.unlock() // Balances above ⬆
callback(nil)
break
}
if let cache = session.configuration.urlCache, let me = self as? URLSessionDataTask {
let bag: Bag<(URLProtocol?) -> Void> = Bag()
bag.values.append(callback)
Expand All @@ -145,11 +149,11 @@ open class URLSessionTask : NSObject, NSCopying {
_protocolLock.unlock() // Balances above ⬆

cache.getCachedResponse(for: me) { (response) in
let urlProtocol = self._protocolClass.init(task: self, cachedResponse: response, client: nil)
let urlProtocol = protocolClass.init(task: self, cachedResponse: response, client: nil)
self._satisfyProtocolRequest(with: urlProtocol)
}
} else {
let urlProtocol = _protocolClass.init(task: self, cachedResponse: nil, client: nil)
let urlProtocol = protocolClass.init(task: self, cachedResponse: nil, client: nil)
_protocolStorage = .existing(urlProtocol)
_protocolLock.unlock() // Balances above ⬆

Expand Down
50 changes: 50 additions & 0 deletions Tests/Foundation/Tests/TestURLSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,54 @@ class TestURLSession: LoopbackServerTest {
waitForExpectations(timeout: 12)
}

func test_unhandledURLProtocol() {
let urlString = "foobar://127.0.0.1:\(TestURLSession.serverPort)/Nepal"
let url = URL(string: urlString)!
let session = URLSession(configuration: URLSessionConfiguration.default,
delegate: nil,
delegateQueue: nil)
let completionExpectation = expectation(description: "GET \(urlString): Unsupported URL error")
let task = session.dataTask(with: url) { (data, response, _error) in
XCTAssertNil(data)
XCTAssertNil(response)
let error = _error as? URLError
XCTAssertNotNil(error)
XCTAssertEqual(error?.code, .unsupportedURL)
completionExpectation.fulfill()
}
task.resume()

waitForExpectations(timeout: 5) { error in
XCTAssertNil(error)
XCTAssertEqual((task.error as? URLError)?.code, .unsupportedURL)
}
}

func test_requestToNilURL() {
let urlString = "http://127.0.0.1:\(TestURLSession.serverPort)/Nepal"
let url = URL(string: urlString)!
let session = URLSession(configuration: URLSessionConfiguration.default,
delegate: nil,
delegateQueue: nil)
let completionExpectation = expectation(description: "DataTask with nil URL: Unsupported URL error")
var request = URLRequest(url: url)
request.url = nil
let task = session.dataTask(with: request) { (data, response, _error) in
XCTAssertNil(data)
XCTAssertNil(response)
let error = _error as? URLError
XCTAssertNotNil(error)
XCTAssertEqual(error?.code, .unsupportedURL)
completionExpectation.fulfill()
}
task.resume()

waitForExpectations(timeout: 5) { error in
XCTAssertNil(error)
XCTAssertEqual((task.error as? URLError)?.code, .unsupportedURL)
}
}

func test_suspendResumeTask() throws {
let urlString = "http://127.0.0.1:\(TestURLSession.serverPort)/get"
let url = try XCTUnwrap(URL(string: urlString))
Expand Down Expand Up @@ -1819,6 +1867,8 @@ class TestURLSession: LoopbackServerTest {
("test_taskError", test_taskError),
("test_taskCopy", test_taskCopy),
("test_cancelTask", test_cancelTask),
("test_unhandledURLProtocol", test_unhandledURLProtocol),
("test_requestToNilURL", test_requestToNilURL),
/* ⚠️ */ ("test_suspendResumeTask", testExpectedToFail(test_suspendResumeTask, "Occasionally breaks")),
("test_taskTimeout", test_taskTimeout),
("test_verifyRequestHeaders", test_verifyRequestHeaders),
Expand Down