Skip to content

Fix Request streaming memory leak #665

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,8 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate {
}

var state = State.idle
let request: HTTPClient.Request
let requestMethod: HTTPMethod
let requestHost: String

static let maxByteBufferSize = Int(UInt32.max)

Expand Down Expand Up @@ -408,14 +409,15 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate {
maxBodySize <= Self.maxByteBufferSize,
"maxBodyLength is not allowed to exceed 2^32 because ByteBuffer can not store more bytes"
)
self.request = request
self.requestMethod = request.method
self.requestHost = request.host
self.maxBodySize = maxBodySize
}

public func didReceiveHead(task: HTTPClient.Task<Response>, _ head: HTTPResponseHead) -> EventLoopFuture<Void> {
switch self.state {
case .idle:
if self.request.method != .HEAD,
if self.requestMethod != .HEAD,
let contentLength = head.headers.first(name: "Content-Length"),
let announcedBodySize = Int(contentLength),
announcedBodySize > self.maxBodySize {
Expand Down Expand Up @@ -481,9 +483,9 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate {
case .idle:
preconditionFailure("no head received before end")
case .head(let head):
return Response(host: self.request.host, status: head.status, version: head.version, headers: head.headers, body: nil)
return Response(host: self.requestHost, status: head.status, version: head.version, headers: head.headers, body: nil)
case .body(let head, let body):
return Response(host: self.request.host, status: head.status, version: head.version, headers: head.headers, body: body)
return Response(host: self.requestHost, status: head.status, version: head.version, headers: head.headers, body: body)
case .end:
preconditionFailure("request already processed")
case .error(let error):
Expand Down
47 changes: 23 additions & 24 deletions Sources/AsyncHTTPClient/RequestBag+StateMachine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ extension HTTPClient {
extension RequestBag {
struct StateMachine {
fileprivate enum State {
case initialized
case queued(HTTPRequestScheduler)
case initialized(RedirectHandler<Delegate.Response>?)
case queued(HTTPRequestScheduler, RedirectHandler<Delegate.Response>?)
/// if the deadline was exceeded while in the `.queued(_:)` state,
/// we wait until the request pool fails the request with a potential more descriptive error message,
/// if a connection failure has occured while the request was queued.
case deadlineExceededWhileQueued
case executing(HTTPRequestExecutor, RequestStreamState, ResponseStreamState)
case finished(error: Error?)
case redirected(HTTPRequestExecutor, Int, HTTPResponseHead, URL)
case redirected(HTTPRequestExecutor, RedirectHandler<Delegate.Response>, Int, HTTPResponseHead, URL)
case modifying
}

Expand All @@ -55,23 +55,22 @@ extension RequestBag {
case eof
}

case initialized
case initialized(RedirectHandler<Delegate.Response>?)
case buffering(CircularBuffer<ByteBuffer>, next: Next)
case waitingForRemote
}

private var state: State = .initialized
private let redirectHandler: RedirectHandler<Delegate.Response>?
private var state: State

init(redirectHandler: RedirectHandler<Delegate.Response>?) {
self.redirectHandler = redirectHandler
self.state = .initialized(redirectHandler)
}
}
}

extension RequestBag.StateMachine {
mutating func requestWasQueued(_ scheduler: HTTPRequestScheduler) {
guard case .initialized = self.state else {
guard case .initialized(let redirectHandler) = self.state else {
// There might be a race between `requestWasQueued` and `willExecuteRequest`:
//
// If the request is created and passed to the HTTPClient on thread A, it will move into
Expand All @@ -91,7 +90,7 @@ extension RequestBag.StateMachine {
return
}

self.state = .queued(scheduler)
self.state = .queued(scheduler, redirectHandler)
}

enum WillExecuteRequestAction {
Expand All @@ -102,8 +101,8 @@ extension RequestBag.StateMachine {

mutating func willExecuteRequest(_ executor: HTTPRequestExecutor) -> WillExecuteRequestAction {
switch self.state {
case .initialized, .queued:
self.state = .executing(executor, .initialized, .initialized)
case .initialized(let redirectHandler), .queued(_, let redirectHandler):
self.state = .executing(executor, .initialized, .initialized(redirectHandler))
return .none
case .deadlineExceededWhileQueued:
let error: Error = HTTPClientError.deadlineExceeded
Expand All @@ -127,8 +126,8 @@ extension RequestBag.StateMachine {
case .initialized, .queued, .deadlineExceededWhileQueued:
preconditionFailure("A request stream can only be resumed, if the request was started")

case .executing(let executor, .initialized, .initialized):
self.state = .executing(executor, .producing, .initialized)
case .executing(let executor, .initialized, .initialized(let redirectHandler)):
self.state = .executing(executor, .producing, .initialized(redirectHandler))
return .startWriter

case .executing(_, .producing, _):
Expand Down Expand Up @@ -299,11 +298,11 @@ extension RequestBag.StateMachine {
case .initialized, .queued, .deadlineExceededWhileQueued:
preconditionFailure("How can we receive a response, if the request hasn't started yet.")
case .executing(let executor, let requestState, let responseState):
guard case .initialized = responseState else {
guard case .initialized(let redirectHandler) = responseState else {
preconditionFailure("If we receive a response, we must not have received something else before")
}

if let redirectURL = self.redirectHandler?.redirectTarget(
if let redirectHandler = redirectHandler, let redirectURL = redirectHandler.redirectTarget(
status: head.status,
responseHeaders: head.headers
) {
Expand All @@ -312,11 +311,11 @@ extension RequestBag.StateMachine {
// smaller than 3kb.
switch head.contentLength {
case .some(0...(HTTPClient.maxBodySizeRedirectResponse)), .none:
self.state = .redirected(executor, 0, head, redirectURL)
self.state = .redirected(executor, redirectHandler, 0, head, redirectURL)
return .signalBodyDemand(executor)
case .some:
self.state = .finished(error: HTTPClientError.cancelled)
return .redirect(executor, self.redirectHandler!, head, redirectURL)
return .redirect(executor, redirectHandler, head, redirectURL)
}
} else {
self.state = .executing(executor, requestState, .buffering(.init(), next: .askExecutorForMore))
Expand Down Expand Up @@ -369,15 +368,15 @@ extension RequestBag.StateMachine {
} else {
return .none
}
case .redirected(let executor, var receivedBytes, let head, let redirectURL):
case .redirected(let executor, let redirectHandler, var receivedBytes, let head, let redirectURL):
let partsLength = buffer.reduce(into: 0) { $0 += $1.readableBytes }
receivedBytes += partsLength

if receivedBytes > HTTPClient.maxBodySizeRedirectResponse {
self.state = .finished(error: HTTPClientError.cancelled)
return .redirect(executor, self.redirectHandler!, head, redirectURL)
return .redirect(executor, redirectHandler, head, redirectURL)
} else {
self.state = .redirected(executor, receivedBytes, head, redirectURL)
self.state = .redirected(executor, redirectHandler, receivedBytes, head, redirectURL)
return .signalBodyDemand(executor)
}

Expand Down Expand Up @@ -428,9 +427,9 @@ extension RequestBag.StateMachine {
self.state = .executing(executor, requestState, .buffering(newChunks, next: .eof))
return .consume(first)

case .redirected(_, _, let head, let redirectURL):
case .redirected(_, let redirectHandler, _, let head, let redirectURL):
self.state = .finished(error: nil)
return .redirect(self.redirectHandler!, head, redirectURL)
return .redirect(redirectHandler, head, redirectURL)

case .finished(error: .some):
return .none
Expand Down Expand Up @@ -553,7 +552,7 @@ extension RequestBag.StateMachine {

mutating func deadlineExceeded() -> DeadlineExceededAction {
switch self.state {
case .queued(let queuer):
case .queued(let queuer, _):
/// We do not fail the request immediately because we want to give the scheduler a chance of throwing a better error message
/// We therefore depend on the scheduler failing the request after we cancel the request.
self.state = .deadlineExceededWhileQueued
Expand Down Expand Up @@ -582,7 +581,7 @@ extension RequestBag.StateMachine {
case .initialized:
self.state = .finished(error: error)
return .failTask(error, nil, nil)
case .queued(let queuer):
case .queued(let queuer, _):
self.state = .finished(error: error)
return .failTask(error, queuer, nil)
case .executing(let executor, let requestState, .buffering(_, next: .eof)):
Expand Down
3 changes: 2 additions & 1 deletion Sources/AsyncHTTPClient/RequestBag.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
}

private let delegate: Delegate
private let request: HTTPClient.Request
private var request: HTTPClient.Request

// the request state is synchronized on the task eventLoop
private var state: StateMachine
Expand Down Expand Up @@ -126,6 +126,7 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
guard let body = self.request.body else {
preconditionFailure("Expected to have a body, if the `HTTPRequestStateMachine` resume a request stream")
}
self.request.body = nil

let writer = HTTPClient.Body.StreamWriter {
self.writeNextRequestPart($0)
Expand Down
1 change: 1 addition & 0 deletions Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ extension RequestBagTests {
("testRedirectWith3KBBody", testRedirectWith3KBBody),
("testRedirectWith4KBBodyAnnouncedInResponseHead", testRedirectWith4KBBodyAnnouncedInResponseHead),
("testRedirectWith4KBBodyNotAnnouncedInResponseHead", testRedirectWith4KBBodyNotAnnouncedInResponseHead),
("testWeDontLeakTheRequestIfTheRequestWriterWasCapturedByAPromise", testWeDontLeakTheRequestIfTheRequestWriterWasCapturedByAPromise),
]
}
}
49 changes: 49 additions & 0 deletions Tests/AsyncHTTPClientTests/RequestBagTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import Logging
import NIOCore
import NIOEmbedded
import NIOHTTP1
import NIOPosix
import XCTest

final class RequestBagTests: XCTestCase {
Expand Down Expand Up @@ -836,6 +837,54 @@ final class RequestBagTests: XCTestCase {

XCTAssertTrue(redirectTriggered)
}

func testWeDontLeakTheRequestIfTheRequestWriterWasCapturedByAPromise() {
final class LeakDetector {}

let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) }

let httpClient = HTTPClient(eventLoopGroupProvider: .shared(group))
defer { XCTAssertNoThrow(try httpClient.shutdown().wait()) }

let httpBin = HTTPBin()
defer { XCTAssertNoThrow(try httpBin.shutdown()) }

var leakDetector = LeakDetector()

do {
var maybeRequest: HTTPClient.Request?
XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/", method: .POST))
guard var request = maybeRequest else { return XCTFail("Expected to have a request here") }

let writerPromise = group.any().makePromise(of: HTTPClient.Body.StreamWriter.self)
let donePromise = group.any().makePromise(of: Void.self)
request.body = .stream { [leakDetector] writer in
_ = leakDetector
writerPromise.succeed(writer)
return donePromise.futureResult
}

let resultFuture = httpClient.execute(request: request)
request.body = nil
writerPromise.futureResult.whenSuccess { writer in
writer.write(.byteBuffer(ByteBuffer(string: "hello"))).map {
print("written")
}.cascade(to: donePromise)
}
XCTAssertNoThrow(try donePromise.futureResult.wait())
print("HTTP sent")

var result: HTTPClient.Response?
XCTAssertNoThrow(result = try resultFuture.wait())

XCTAssertEqual(.ok, result?.status)
let body = result?.body.map { String(buffer: $0) }
XCTAssertNotNil(body)
print("HTTP done")
}
XCTAssertTrue(isKnownUniquelyReferenced(&leakDetector))
}
}

extension HTTPClient.Task {
Expand Down