Skip to content

URLSessionWebSocketTask.receive() not finishing if server closes connection without a close packet #4673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Sources/FoundationNetworking/URLSession/URLSessionTask.swift
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,12 @@ open class URLSessionWebSocketTask : URLSessionTask {
}
}

open override var error: Error? {
didSet {
doPendingWork()
}
}

private var sendBuffer = [(Message, (Error?) -> Void)]()
private var receiveBuffer = [Message]()
private var receiveCompletionHandlers = [(Result<Message, Error>) -> Void]()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ internal class _WebSocketURLProtocol: _HTTPURLProtocol {
return .completeTask
}

override func completeTask() {
if let webSocketTask = task as? URLSessionWebSocketTask {
webSocketTask.close(code: .normalClosure, reason: nil)
}
super.completeTask()
}

func sendWebSocketData(_ data: Data, flags: _EasyHandle.WebSocketFlags) throws {
try easyHandle.sendWebSocketsData(data, flags: flags)
}
Expand Down
79 changes: 53 additions & 26 deletions Tests/Foundation/HTTPServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -912,8 +912,28 @@ public class TestURLSessionServer: CustomStringConvertible {
"Connection: Upgrade"]

let expectFullRequestResponseTests: Bool
let sendClosePacket: Bool
let completeUpgrade: Bool

let uri = request.uri
if uri.count > "/web-socket/".count {
switch uri {
case "/web-socket":
expectFullRequestResponseTests = true
completeUpgrade = true
sendClosePacket = true
case "/web-socket/semi-abrupt-close":
expectFullRequestResponseTests = false
completeUpgrade = true
sendClosePacket = false
case "/web-socket/abrupt-close":
expectFullRequestResponseTests = false
completeUpgrade = false
sendClosePacket = false
default:
guard uri.count > "/web-socket/".count else {
NSLog("Expected Sec-WebSocket-Protocol")
throw InternalServerError.badHeaders
}
let expectedProtocol = String(uri.suffix(from: uri.index(uri.startIndex, offsetBy: "/web-socket/".count)))
guard let receivedProtocolStr = request.getHeader(for: "Sec-WebSocket-Protocol"),
expectedProtocol == receivedProtocolStr.components(separatedBy: ", ")[0] else {
Expand All @@ -922,10 +942,12 @@ public class TestURLSessionServer: CustomStringConvertible {
}
responseHeaders.append("Sec-WebSocket-Protocol: \(expectedProtocol)")
expectFullRequestResponseTests = false
} else {
expectFullRequestResponseTests = true
completeUpgrade = true
sendClosePacket = true
}

guard completeUpgrade else { return }

var upgradeResponse = _HTTPResponse(response: .SWITCHING_PROTOCOLS, headers: responseHeaders)
// Lacking an available SHA1 implementation, we'll only include this response for a well-known key
if "dGhlIHNhbXBsZSBub25jZQ==" == request.getHeader(for: "sec-websocket-key") {
Expand All @@ -940,10 +962,11 @@ public class TestURLSessionServer: CustomStringConvertible {
let closePayload = Data([UInt8(closeCode >> 8),
UInt8(closeCode & 0xFF)]) + closeReason

let pingPayload = "Hi".data(using: .utf8)!

if expectFullRequestResponseTests {
let stringPayload = "Hello".data(using: .utf8)!
let dataPayload = Data([0x20, 0x22, 0x10, 0x03])
let pingPayload = "Hi".data(using: .utf8)!

// Receive a string message
guard let stringFrame = try httpServer.tcpSocket.readData(),
Expand Down Expand Up @@ -981,32 +1004,36 @@ public class TestURLSessionServer: CustomStringConvertible {
}
// ... and pong it
try httpServer.tcpSocket.writeRawData(Data([0x8a, 0x00]))

// Send a ping
let sendPingFrame = Data([0x89, UInt8(pingPayload.count)]) + pingPayload
try httpServer.tcpSocket.writeRawData(sendPingFrame)
// ... and receive its pong
guard let pongFrame = try httpServer.tcpSocket.readData(),
pongFrame.count == (2 + 4 + pingPayload.count),
Data(pongFrame.prefix(2)) == Data([0x8a, (0x80 | UInt8(pingPayload.count))]),
try unmaskedPayload(from: pongFrame) == pingPayload else {
NSLog("Invalid pong frame")
throw InternalServerError.badBody
}

// Send a close
let sendCloseFrame = Data([0x88, UInt8(closePayload.count)]) + closePayload
try httpServer.tcpSocket.writeRawData(sendCloseFrame)
}

// Receive a close message
guard let closeFrame = try httpServer.tcpSocket.readData(),
closeFrame.count == (2 + 4 + closePayload.count),
Data(closeFrame.prefix(2)) == Data([0x88, (0x80 | UInt8(closePayload.count))]),
try unmaskedPayload(from: closeFrame) == closePayload else {
NSLog("Invalid close payload")
// Send a ping
let sendPingFrame = Data([0x89, UInt8(pingPayload.count)]) + pingPayload
try httpServer.tcpSocket.writeRawData(sendPingFrame)
// ... and receive its pong
guard let pongFrame = try httpServer.tcpSocket.readData(),
pongFrame.count == (2 + 4 + pingPayload.count),
Data(pongFrame.prefix(2)) == Data([0x8a, (0x80 | UInt8(pingPayload.count))]),
try unmaskedPayload(from: pongFrame) == pingPayload else {
NSLog("Invalid pong frame")
throw InternalServerError.badBody
}

if sendClosePacket {
if expectFullRequestResponseTests {
// Send a close
let sendCloseFrame = Data([0x88, UInt8(closePayload.count)]) + closePayload
try httpServer.tcpSocket.writeRawData(sendCloseFrame)
}

// Receive a close message
guard let closeFrame = try httpServer.tcpSocket.readData(),
closeFrame.count == (2 + 4 + closePayload.count),
Data(closeFrame.prefix(2)) == Data([0x88, (0x80 | UInt8(closePayload.count))]),
try unmaskedPayload(from: closeFrame) == closePayload else {
NSLog("Invalid close payload")
throw InternalServerError.badBody
}
}

} catch {
let badBodyCloseFrame = Data([0x88, 0x08, 0x03, 0xEA, 0x42, 0x75, 0x68, 0x42, 0x79, 0x65])
Expand Down
9 changes: 7 additions & 2 deletions Tests/Foundation/Tests/TestDecimal.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@
// See http://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//

import Foundation
import XCTest
#if NS_FOUNDATION_ALLOWS_TESTABLE_IMPORT
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was necessary for TestDecimal to build on macOS against the correct Foundation library. See also e9e183c#r1011034654 . I was mistaken in thinking this change was unnecessary.

#if canImport(SwiftFoundation) && !DEPLOYMENT_RUNTIME_OBJC
@testable import SwiftFoundation
#else
@testable import Foundation
#endif
#endif

class TestDecimal: XCTestCase {

Expand Down
97 changes: 95 additions & 2 deletions Tests/Foundation/Tests/TestURLSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1894,9 +1894,17 @@ class TestURLSession: LoopbackServerTest {
}

try await task.sendPing()

wait(for: [delegate.expectation], timeout: 50)

do {
_ = try await task.receive()
XCTFail("Expected to throw when receiving on closed task")
} catch {
let urlError = try XCTUnwrap(error as? URLError)
XCTAssertEqual(urlError._nsError.code, NSURLErrorNetworkConnectionLost)
}

let callbacks = [ "urlSession(_:webSocketTask:didOpenWithProtocol:)",
"urlSession(_:webSocketTask:didCloseWith:reason:)",
"urlSession(_:task:didCompleteWithError:)" ]
Expand Down Expand Up @@ -1925,15 +1933,98 @@ class TestURLSession: LoopbackServerTest {
wait(for: [delegate.expectation], timeout: 50)

let callbacks = [ "urlSession(_:webSocketTask:didOpenWithProtocol:)",
"urlSession(_:webSocketTask:didCloseWith:reason:)",
"urlSession(_:task:didCompleteWithError:)" ]
XCTAssertEqual(delegate.callbacks.count, callbacks.count)
XCTAssertEqual(delegate.callbacks, callbacks, "Callbacks for \(#function)")

XCTAssertEqual(task.closeCode, .normalClosure)
XCTAssertEqual(task.closeReason, "BuhBye".data(using: .utf8))
}
#endif

func test_webSocketAbruptClose() async throws {
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
guard URLSessionWebSocketTask.supportsWebSockets else {
print("libcurl lacks WebSockets support, skipping \(#function)")
return
}

let urlString = "ws://127.0.0.1:\(TestURLSession.serverPort)/web-socket/abrupt-close"
let url = try XCTUnwrap(URL(string: urlString))
let request = URLRequest(url: url)

let delegate = SessionDelegate(with: expectation(description: "\(urlString): Connect"))
let task = delegate.runWebSocketTask(with: request, timeoutInterval: 4)

do {
_ = try await task.receive()
XCTFail("Expected to throw when server closes connection")
} catch {
let urlError = try XCTUnwrap(error as? URLError)
XCTAssertEqual(urlError._nsError.code, NSURLErrorBadServerResponse)
}

wait(for: [delegate.expectation], timeout: 50)

do {
_ = try await task.receive()
XCTFail("Expected to throw when receiving on closed connection")
} catch {
let urlError = try XCTUnwrap(error as? URLError)
XCTAssertEqual(urlError._nsError.code, NSURLErrorBadServerResponse)
}

let callbacks = [ "urlSession(_:task:didCompleteWithError:)" ]
XCTAssertEqual(delegate.callbacks.count, callbacks.count)
XCTAssertEqual(delegate.callbacks, callbacks, "Callbacks for \(#function)")

XCTAssertEqual(task.closeCode, .invalid)
XCTAssertEqual(task.closeReason, nil)
}

func test_webSocketSemiAbruptClose() async throws {
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
guard URLSessionWebSocketTask.supportsWebSockets else {
print("libcurl lacks WebSockets support, skipping \(#function)")
return
}

let urlString = "ws://127.0.0.1:\(TestURLSession.serverPort)/web-socket/semi-abrupt-close"
let url = try XCTUnwrap(URL(string: urlString))
let request = URLRequest(url: url)

let delegate = SessionDelegate(with: expectation(description: "\(urlString): Connect"))
let task = delegate.runWebSocketTask(with: request, timeoutInterval: 4)

do {
_ = try await task.receive()
XCTFail("Expected to throw when server closes connection")
} catch {
let urlError = try XCTUnwrap(error as? URLError)
XCTAssertEqual(urlError._nsError.code, NSURLErrorNetworkConnectionLost)
}

wait(for: [delegate.expectation], timeout: 50)

do {
_ = try await task.receive()
XCTFail("Expected to throw when receiving on closed connection")
} catch {
let urlError = try XCTUnwrap(error as? URLError)
XCTAssertEqual(urlError._nsError.code, NSURLErrorNetworkConnectionLost)
}

let callbacks = [ "urlSession(_:webSocketTask:didOpenWithProtocol:)",
"urlSession(_:webSocketTask:didCloseWith:reason:)",
"urlSession(_:task:didCompleteWithError:)" ]
XCTAssertEqual(delegate.callbacks.count, callbacks.count)
XCTAssertEqual(delegate.callbacks, callbacks, "Callbacks for \(#function)")

XCTAssertEqual(task.closeCode, .normalClosure)
XCTAssertEqual(task.closeReason, nil)
}
#endif

static var allTests: [(String, (TestURLSession) -> () throws -> Void)] {
var retVal = [
("test_dataTaskWithURL", test_dataTaskWithURL),
Expand Down Expand Up @@ -2011,6 +2102,8 @@ class TestURLSession: LoopbackServerTest {
retVal.append(contentsOf: [
("test_webSocket", asyncTest(test_webSocket)),
("test_webSocketSpecificProtocol", asyncTest(test_webSocketSpecificProtocol)),
("test_webSocketAbruptClose", asyncTest(test_webSocketAbruptClose)),
("test_webSocketSemiAbruptClose", asyncTest(test_webSocketSemiAbruptClose)),
])
}
return retVal
Expand Down