Skip to content

Commit 6b89811

Browse files
committed
Fix Request streaming memory leak
1 parent 59bfb96 commit 6b89811

File tree

5 files changed

+82
-30
lines changed

5 files changed

+82
-30
lines changed

Sources/AsyncHTTPClient/HTTPHandler.swift

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,8 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate {
379379
}
380380

381381
var state = State.idle
382-
let request: HTTPClient.Request
382+
let requestMethod: HTTPMethod
383+
let requestHost: String
383384

384385
static let maxByteBufferSize = Int(UInt32.max)
385386

@@ -408,14 +409,15 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate {
408409
maxBodySize <= Self.maxByteBufferSize,
409410
"maxBodyLength is not allowed to exceed 2^32 because ByteBuffer can not store more bytes"
410411
)
411-
self.request = request
412+
self.requestMethod = request.method
413+
self.requestHost = request.host
412414
self.maxBodySize = maxBodySize
413415
}
414416

415417
public func didReceiveHead(task: HTTPClient.Task<Response>, _ head: HTTPResponseHead) -> EventLoopFuture<Void> {
416418
switch self.state {
417419
case .idle:
418-
if self.request.method != .HEAD,
420+
if self.requestMethod != .HEAD,
419421
let contentLength = head.headers.first(name: "Content-Length"),
420422
let announcedBodySize = Int(contentLength),
421423
announcedBodySize > self.maxBodySize {
@@ -481,9 +483,9 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate {
481483
case .idle:
482484
preconditionFailure("no head received before end")
483485
case .head(let head):
484-
return Response(host: self.request.host, status: head.status, version: head.version, headers: head.headers, body: nil)
486+
return Response(host: self.requestHost, status: head.status, version: head.version, headers: head.headers, body: nil)
485487
case .body(let head, let body):
486-
return Response(host: self.request.host, status: head.status, version: head.version, headers: head.headers, body: body)
488+
return Response(host: self.requestHost, status: head.status, version: head.version, headers: head.headers, body: body)
487489
case .end:
488490
preconditionFailure("request already processed")
489491
case .error(let error):

Sources/AsyncHTTPClient/RequestBag+StateMachine.swift

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@ extension HTTPClient {
2929
extension RequestBag {
3030
struct StateMachine {
3131
fileprivate enum State {
32-
case initialized
33-
case queued(HTTPRequestScheduler)
32+
case initialized(RedirectHandler<Delegate.Response>?)
33+
case queued(HTTPRequestScheduler, RedirectHandler<Delegate.Response>?)
3434
/// if the deadline was exceeded while in the `.queued(_:)` state,
3535
/// we wait until the request pool fails the request with a potential more descriptive error message,
3636
/// if a connection failure has occured while the request was queued.
3737
case deadlineExceededWhileQueued
3838
case executing(HTTPRequestExecutor, RequestStreamState, ResponseStreamState)
3939
case finished(error: Error?)
40-
case redirected(HTTPRequestExecutor, Int, HTTPResponseHead, URL)
40+
case redirected(HTTPRequestExecutor, RedirectHandler<Delegate.Response>, Int, HTTPResponseHead, URL)
4141
case modifying
4242
}
4343

@@ -55,23 +55,22 @@ extension RequestBag {
5555
case eof
5656
}
5757

58-
case initialized
58+
case initialized(RedirectHandler<Delegate.Response>?)
5959
case buffering(CircularBuffer<ByteBuffer>, next: Next)
6060
case waitingForRemote
6161
}
6262

63-
private var state: State = .initialized
64-
private let redirectHandler: RedirectHandler<Delegate.Response>?
63+
private var state: State
6564

6665
init(redirectHandler: RedirectHandler<Delegate.Response>?) {
67-
self.redirectHandler = redirectHandler
66+
self.state = .initialized(redirectHandler)
6867
}
6968
}
7069
}
7170

7271
extension RequestBag.StateMachine {
7372
mutating func requestWasQueued(_ scheduler: HTTPRequestScheduler) {
74-
guard case .initialized = self.state else {
73+
guard case .initialized(let redirectHandler) = self.state else {
7574
// There might be a race between `requestWasQueued` and `willExecuteRequest`:
7675
//
7776
// If the request is created and passed to the HTTPClient on thread A, it will move into
@@ -91,7 +90,7 @@ extension RequestBag.StateMachine {
9190
return
9291
}
9392

94-
self.state = .queued(scheduler)
93+
self.state = .queued(scheduler, redirectHandler)
9594
}
9695

9796
enum WillExecuteRequestAction {
@@ -102,8 +101,8 @@ extension RequestBag.StateMachine {
102101

103102
mutating func willExecuteRequest(_ executor: HTTPRequestExecutor) -> WillExecuteRequestAction {
104103
switch self.state {
105-
case .initialized, .queued:
106-
self.state = .executing(executor, .initialized, .initialized)
104+
case .initialized(let redirectHandler), .queued(_, let redirectHandler):
105+
self.state = .executing(executor, .initialized, .initialized(redirectHandler))
107106
return .none
108107
case .deadlineExceededWhileQueued:
109108
let error: Error = HTTPClientError.deadlineExceeded
@@ -127,8 +126,8 @@ extension RequestBag.StateMachine {
127126
case .initialized, .queued, .deadlineExceededWhileQueued:
128127
preconditionFailure("A request stream can only be resumed, if the request was started")
129128

130-
case .executing(let executor, .initialized, .initialized):
131-
self.state = .executing(executor, .producing, .initialized)
129+
case .executing(let executor, .initialized, .initialized(let redirectHandler)):
130+
self.state = .executing(executor, .producing, .initialized(redirectHandler))
132131
return .startWriter
133132

134133
case .executing(_, .producing, _):
@@ -299,11 +298,11 @@ extension RequestBag.StateMachine {
299298
case .initialized, .queued, .deadlineExceededWhileQueued:
300299
preconditionFailure("How can we receive a response, if the request hasn't started yet.")
301300
case .executing(let executor, let requestState, let responseState):
302-
guard case .initialized = responseState else {
301+
guard case .initialized(let redirectHandler) = responseState else {
303302
preconditionFailure("If we receive a response, we must not have received something else before")
304303
}
305304

306-
if let redirectURL = self.redirectHandler?.redirectTarget(
305+
if let redirectHandler = redirectHandler, let redirectURL = redirectHandler.redirectTarget(
307306
status: head.status,
308307
responseHeaders: head.headers
309308
) {
@@ -312,11 +311,11 @@ extension RequestBag.StateMachine {
312311
// smaller than 3kb.
313312
switch head.contentLength {
314313
case .some(0...(HTTPClient.maxBodySizeRedirectResponse)), .none:
315-
self.state = .redirected(executor, 0, head, redirectURL)
314+
self.state = .redirected(executor, redirectHandler, 0, head, redirectURL)
316315
return .signalBodyDemand(executor)
317316
case .some:
318317
self.state = .finished(error: HTTPClientError.cancelled)
319-
return .redirect(executor, self.redirectHandler!, head, redirectURL)
318+
return .redirect(executor, redirectHandler, head, redirectURL)
320319
}
321320
} else {
322321
self.state = .executing(executor, requestState, .buffering(.init(), next: .askExecutorForMore))
@@ -369,15 +368,15 @@ extension RequestBag.StateMachine {
369368
} else {
370369
return .none
371370
}
372-
case .redirected(let executor, var receivedBytes, let head, let redirectURL):
371+
case .redirected(let executor, let redirectHandler, var receivedBytes, let head, let redirectURL):
373372
let partsLength = buffer.reduce(into: 0) { $0 += $1.readableBytes }
374373
receivedBytes += partsLength
375374

376375
if receivedBytes > HTTPClient.maxBodySizeRedirectResponse {
377376
self.state = .finished(error: HTTPClientError.cancelled)
378-
return .redirect(executor, self.redirectHandler!, head, redirectURL)
377+
return .redirect(executor, redirectHandler, head, redirectURL)
379378
} else {
380-
self.state = .redirected(executor, receivedBytes, head, redirectURL)
379+
self.state = .redirected(executor, redirectHandler, receivedBytes, head, redirectURL)
381380
return .signalBodyDemand(executor)
382381
}
383382

@@ -428,9 +427,9 @@ extension RequestBag.StateMachine {
428427
self.state = .executing(executor, requestState, .buffering(newChunks, next: .eof))
429428
return .consume(first)
430429

431-
case .redirected(_, _, let head, let redirectURL):
430+
case .redirected(_, let redirectHandler, _, let head, let redirectURL):
432431
self.state = .finished(error: nil)
433-
return .redirect(self.redirectHandler!, head, redirectURL)
432+
return .redirect(redirectHandler, head, redirectURL)
434433

435434
case .finished(error: .some):
436435
return .none
@@ -553,7 +552,7 @@ extension RequestBag.StateMachine {
553552

554553
mutating func deadlineExceeded() -> DeadlineExceededAction {
555554
switch self.state {
556-
case .queued(let queuer):
555+
case .queued(let queuer, _):
557556
/// We do not fail the request immediately because we want to give the scheduler a chance of throwing a better error message
558557
/// We therefore depend on the scheduler failing the request after we cancel the request.
559558
self.state = .deadlineExceededWhileQueued
@@ -582,7 +581,7 @@ extension RequestBag.StateMachine {
582581
case .initialized:
583582
self.state = .finished(error: error)
584583
return .failTask(error, nil, nil)
585-
case .queued(let queuer):
584+
case .queued(let queuer, _):
586585
self.state = .finished(error: error)
587586
return .failTask(error, queuer, nil)
588587
case .executing(let executor, let requestState, .buffering(_, next: .eof)):

Sources/AsyncHTTPClient/RequestBag.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
3333
}
3434

3535
private let delegate: Delegate
36-
private let request: HTTPClient.Request
36+
private var request: HTTPClient.Request
3737

3838
// the request state is synchronized on the task eventLoop
3939
private var state: StateMachine
@@ -126,6 +126,7 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
126126
guard let body = self.request.body else {
127127
preconditionFailure("Expected to have a body, if the `HTTPRequestStateMachine` resume a request stream")
128128
}
129+
self.request.body = nil
129130

130131
let writer = HTTPClient.Body.StreamWriter {
131132
self.writeNextRequestPart($0)

Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ extension RequestBagTests {
4040
("testRedirectWith3KBBody", testRedirectWith3KBBody),
4141
("testRedirectWith4KBBodyAnnouncedInResponseHead", testRedirectWith4KBBodyAnnouncedInResponseHead),
4242
("testRedirectWith4KBBodyNotAnnouncedInResponseHead", testRedirectWith4KBBodyNotAnnouncedInResponseHead),
43+
("testWeDontLeakTheRequestIfTheRequestWriterWasCapturedByAPromise", testWeDontLeakTheRequestIfTheRequestWriterWasCapturedByAPromise),
4344
]
4445
}
4546
}

Tests/AsyncHTTPClientTests/RequestBagTests.swift

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import Logging
1717
import NIOCore
1818
import NIOEmbedded
1919
import NIOHTTP1
20+
import NIOPosix
2021
import XCTest
2122

2223
final class RequestBagTests: XCTestCase {
@@ -836,6 +837,54 @@ final class RequestBagTests: XCTestCase {
836837

837838
XCTAssertTrue(redirectTriggered)
838839
}
840+
841+
func testWeDontLeakTheRequestIfTheRequestWriterWasCapturedByAPromise() {
842+
final class LeakDetector {}
843+
844+
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
845+
defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) }
846+
847+
let httpClient = HTTPClient(eventLoopGroupProvider: .shared(group))
848+
defer { XCTAssertNoThrow(try httpClient.shutdown().wait()) }
849+
850+
let httpBin = HTTPBin()
851+
defer { XCTAssertNoThrow(try httpBin.shutdown()) }
852+
853+
var leakDetector = LeakDetector()
854+
855+
do {
856+
var maybeRequest: HTTPClient.Request?
857+
XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/", method: .POST))
858+
guard var request = maybeRequest else { return XCTFail("Expected to have a request here") }
859+
860+
let writerPromise = group.any().makePromise(of: HTTPClient.Body.StreamWriter.self)
861+
let donePromise = group.any().makePromise(of: Void.self)
862+
request.body = .stream { [leakDetector] writer in
863+
_ = leakDetector
864+
writerPromise.succeed(writer)
865+
return donePromise.futureResult
866+
}
867+
868+
let resultFuture = httpClient.execute(request: request)
869+
request.body = nil
870+
writerPromise.futureResult.whenSuccess { writer in
871+
writer.write(.byteBuffer(ByteBuffer(string: "hello"))).map {
872+
print("written")
873+
}.cascade(to: donePromise)
874+
}
875+
XCTAssertNoThrow(try donePromise.futureResult.wait())
876+
print("HTTP sent")
877+
878+
var result: HTTPClient.Response?
879+
XCTAssertNoThrow(result = try resultFuture.wait())
880+
881+
XCTAssertEqual(.ok, result?.status)
882+
let body = result?.body.map { String(buffer: $0) }
883+
XCTAssertNotNil(body)
884+
print("HTTP done")
885+
}
886+
XCTAssertTrue(isKnownUniquelyReferenced(&leakDetector))
887+
}
839888
}
840889

841890
extension HTTPClient.Task {

0 commit comments

Comments
 (0)