Skip to content

Limit max recursion depth delivering body parts #611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions Sources/AsyncHTTPClient/RequestBag.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ import NIOHTTP1
import NIOSSL

final class RequestBag<Delegate: HTTPClientResponseDelegate> {
/// Defends against the call stack getting too large when consuming body parts.
///
/// If the response body comes in lots of tiny chunks, we'll deliver those tiny chunks to users
/// one at a time.
private static var maxConsumeBodyPartStackDepth: Int {
50
}

let task: HTTPClient.Task<Delegate.Response>
var eventLoop: EventLoop {
self.task.eventLoop
Expand All @@ -30,6 +38,9 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
// the request state is synchronized on the task eventLoop
private var state: StateMachine

// the consume body part stack depth is synchronized on the task event loop.
private var consumeBodyPartStackDepth: Int

// MARK: HTTPClientTask properties

var logger: Logger {
Expand All @@ -55,6 +66,7 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
self.eventLoopPreference = eventLoopPreference
self.task = task
self.state = .init(redirectHandler: redirectHandler)
self.consumeBodyPartStackDepth = 0
self.request = request
self.connectionDeadline = connectionDeadline
self.requestOptions = requestOptions
Expand Down Expand Up @@ -290,16 +302,39 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
private func consumeMoreBodyData0(resultOfPreviousConsume result: Result<Void, Error>) {
self.task.eventLoop.assertInEventLoop()

// We get defensive here about the maximum stack depth. It's possible for the `didReceiveBodyPart`
// future to be returned to us completed. If it is, we will recurse back into this method. To
// break that recursion we have a max stack depth which we increment and decrement in this method:
// if it gets too large, instead of recurring we'll insert an `eventLoop.execute`, which will
// manually break the recursion and unwind the stack.
//
// Note that we don't bother starting this at the various other call sites that _begin_ stacks
// that risk ending up in this loop. That's because we don't need an accurate count: our limit is
// a best-effort target anyway, one stack frame here or there does not put us at risk. We're just
// trying to prevent ourselves looping out of control.
self.consumeBodyPartStackDepth += 1
defer {
self.consumeBodyPartStackDepth -= 1
assert(self.consumeBodyPartStackDepth >= 0)
}

let consumptionAction = self.state.consumeMoreBodyData(resultOfPreviousConsume: result)

switch consumptionAction {
case .consume(let byteBuffer):
self.delegate.didReceiveBodyPart(task: self.task, byteBuffer)
.hop(to: self.task.eventLoop)
.whenComplete {
switch $0 {
.whenComplete { result in
switch result {
case .success:
self.consumeMoreBodyData0(resultOfPreviousConsume: $0)
if self.consumeBodyPartStackDepth < Self.maxConsumeBodyPartStackDepth {
self.consumeMoreBodyData0(resultOfPreviousConsume: result)
} else {
// We need to unwind the stack, let's take a break.
self.task.eventLoop.execute {
self.consumeMoreBodyData0(resultOfPreviousConsume: result)
}
}
case .failure(let error):
self.fail(error)
}
Expand Down
1 change: 1 addition & 0 deletions Tests/AsyncHTTPClientTests/HTTP2ClientTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ extension HTTP2ClientTests {
("testH2CanHandleRequestsThatHaveAlreadyHitTheDeadline", testH2CanHandleRequestsThatHaveAlreadyHitTheDeadline),
("testStressCancelingRunningRequestFromDifferentThreads", testStressCancelingRunningRequestFromDifferentThreads),
("testPlatformConnectErrorIsForwardedOnTimeout", testPlatformConnectErrorIsForwardedOnTimeout),
("testMassiveDownload", testMassiveDownload),
]
}
}
13 changes: 13 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,19 @@ class HTTP2ClientTests: XCTestCase {
)
}
}

func testMassiveDownload() {
let bin = HTTPBin(.http2(compress: false))
defer { XCTAssertNoThrow(try bin.shutdown()) }
let client = self.makeDefaultHTTPClient()
defer { XCTAssertNoThrow(try client.syncShutdown()) }
var response: HTTPClient.Response?
XCTAssertNoThrow(response = try client.get(url: "https://localhost:\(bin.port)/mega-chunked").wait())

XCTAssertEqual(.ok, response?.status)
XCTAssertEqual(response?.version, .http2)
XCTAssertEqual(response?.body?.readableBytes, 10_000)
}
}

private final class HeadReceivedCallback: HTTPClientResponseDelegate {
Expand Down
19 changes: 19 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,22 @@ internal final class HTTPBinHandler: ChannelInboundHandler {
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
}

func writeManyChunks(context: ChannelHandlerContext) {
// This tests receiving a lot of tiny chunks: they must all be sent in a single flush or the test doesn't work.
let headers = HTTPHeaders([("Transfer-Encoding", "chunked")])

context.write(self.wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers))), promise: nil)
let message = ByteBuffer(integer: UInt8(ascii: "a"))

// This number (10k) is load-bearing and a bit magic: it has been experimentally verified as being sufficient to blow the stack
// in the old implementation on all testing platforms. Please don't change it without good reason.
for _ in 0..<10_000 {
context.write(wrapOutboundOut(.body(.byteBuffer(message))), promise: nil)
}

context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
self.isServingRequest = true
switch self.unwrapInboundIn(data) {
Expand Down Expand Up @@ -863,6 +879,9 @@ internal final class HTTPBinHandler: ChannelInboundHandler {
case "/chunked":
self.writeChunked(context: context)
return
case "/mega-chunked":
self.writeManyChunks(context: context)
return
case "/close-on-response":
var headers = self.responseHeaders
headers.replaceOrAdd(name: "connection", value: "close")
Expand Down
1 change: 1 addition & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ extension HTTPClientTests {
("testRequestSpecificTLS", testRequestSpecificTLS),
("testConnectionPoolSizeConfigValueIsRespected", testConnectionPoolSizeConfigValueIsRespected),
("testRequestWithHeaderTransferEncodingIdentityDoesNotFail", testRequestWithHeaderTransferEncodingIdentityDoesNotFail),
("testMassiveDownload", testMassiveDownload),
]
}
}
9 changes: 9 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3454,4 +3454,13 @@ class HTTPClientTests: XCTestCase {

XCTAssertNoThrow(try client.execute(request: request).wait())
}

func testMassiveDownload() {
var response: HTTPClient.Response?
XCTAssertNoThrow(response = try self.defaultClient.get(url: "\(self.defaultHTTPBinURLPrefix)mega-chunked").wait())

XCTAssertEqual(.ok, response?.status)
XCTAssertEqual(response?.version, .http1_1)
XCTAssertEqual(response?.body?.readableBytes, 10_000)
}
}