Skip to content

Allow HTTPClientRequest to be executed multiple times if body is an AsyncSequence #620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,25 @@

#if compiler(>=5.5.2) && canImport(_Concurrency)
import struct Foundation.URL
import NIOCore
import NIOHTTP1

@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension HTTPClientRequest {
struct Prepared {
enum Body {
case asyncSequence(
length: RequestBodyLength,
nextBodyPart: (ByteBufferAllocator) async throws -> ByteBuffer?
)
case sequence(
length: RequestBodyLength,
canBeConsumedMultipleTimes: Bool,
makeCompleteBody: (ByteBufferAllocator) -> ByteBuffer
)
case byteBuffer(ByteBuffer)
}

var url: URL
var poolKey: ConnectionPool.Key
var requestFramingMetadata: RequestFramingMetadata
Expand Down Expand Up @@ -53,11 +67,25 @@ extension HTTPClientRequest.Prepared {
uri: deconstructedURL.uri,
headers: headers
),
body: request.body
body: request.body.map { .init($0) }
)
}
}

@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension HTTPClientRequest.Prepared.Body {
init(_ body: HTTPClientRequest.Body) {
switch body.mode {
case .asyncSequence(let length, let makeAsyncIterator):
self = .asyncSequence(length: length, nextBodyPart: makeAsyncIterator())
case .sequence(let length, let canBeConsumedMultipleTimes, let makeCompleteBody):
self = .sequence(length: length, canBeConsumedMultipleTimes: canBeConsumedMultipleTimes, makeCompleteBody: makeCompleteBody)
case .byteBuffer(let byteBuffer):
self = .byteBuffer(byteBuffer)
}
}
}

@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension RequestBodyLength {
init(_ body: HTTPClientRequest.Body?) {
Expand Down
50 changes: 36 additions & 14 deletions Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,26 @@ extension HTTPClientRequest {
public struct Body {
@usableFromInline
internal enum Mode {
case asyncSequence(length: RequestBodyLength, (ByteBufferAllocator) async throws -> ByteBuffer?)
case sequence(length: RequestBodyLength, canBeConsumedMultipleTimes: Bool, (ByteBufferAllocator) -> ByteBuffer)
/// - parameters:
/// - length: complete body length.
/// If `length` is `.known`, `nextBodyPart` is not allowed to produce more bytes than `length` defines.
/// - makeAsyncIterator: Creates a new async iterator under the hood and returns a function which will call `next()` on it.
/// The returned function then produce the next body buffer asynchronously.
/// We use a closure as abstraction instead of an existential to enable specialization.
case asyncSequence(
length: RequestBodyLength,
makeAsyncIterator: () -> ((ByteBufferAllocator) async throws -> ByteBuffer?)
)
/// - parameters:
/// - length: complete body length.
/// If `length` is `.known`, `nextBodyPart` is not allowed to produce more bytes than `length` defines.
/// - canBeConsumedMultipleTimes: if `makeBody` can be called multiple times and returns the same result.
/// - makeCompleteBody: function to produce the complete body.
case sequence(
length: RequestBodyLength,
canBeConsumedMultipleTimes: Bool,
makeCompleteBody: (ByteBufferAllocator) -> ByteBuffer
)
case byteBuffer(ByteBuffer)
}

Expand Down Expand Up @@ -180,9 +198,11 @@ extension HTTPClientRequest.Body {
_ sequenceOfBytes: SequenceOfBytes,
length: Length
) -> Self where SequenceOfBytes.Element == ByteBuffer {
var iterator = sequenceOfBytes.makeAsyncIterator()
let body = self.init(.asyncSequence(length: length.storage) { _ -> ByteBuffer? in
try await iterator.next()
let body = self.init(.asyncSequence(length: length.storage) {
var iterator = sequenceOfBytes.makeAsyncIterator()
return { _ -> ByteBuffer? in
try await iterator.next()
}
})
return body
}
Expand All @@ -205,16 +225,18 @@ extension HTTPClientRequest.Body {
_ bytes: Bytes,
length: Length
) -> Self where Bytes.Element == UInt8 {
var iterator = bytes.makeAsyncIterator()
let body = self.init(.asyncSequence(length: length.storage) { allocator -> ByteBuffer? in
var buffer = allocator.buffer(capacity: 1024) // TODO: Magic number
while buffer.writableBytes > 0, let byte = try await iterator.next() {
buffer.writeInteger(byte)
}
if buffer.readableBytes > 0 {
return buffer
let body = self.init(.asyncSequence(length: length.storage) {
var iterator = bytes.makeAsyncIterator()
return { allocator -> ByteBuffer? in
var buffer = allocator.buffer(capacity: 1024) // TODO: Magic number
while buffer.writableBytes > 0, let byte = try await iterator.next() {
buffer.writeInteger(byte)
}
if buffer.readableBytes > 0 {
return buffer
}
return nil
}
return nil
})
return body
}
Expand Down
2 changes: 1 addition & 1 deletion Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ extension Transaction: HTTPExecutableRequest {
break

case .startStream(let allocator):
switch self.request.body?.mode {
switch self.request.body {
case .asyncSequence(_, let next):
// it is safe to call this async here. it dispatches...
self.continueRequestBodyStream(allocator, next: next)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ extension AsyncAwaitEndToEndTests {
("testRedirectChangesHostHeader", testRedirectChangesHostHeader),
("testShutdown", testShutdown),
("testCancelingBodyDoesNotCrash", testCancelingBodyDoesNotCrash),
("testAsyncSequenceReuse", testAsyncSequenceReuse),
]
}
}
38 changes: 38 additions & 0 deletions Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,44 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
}
#endif
}

func testAsyncSequenceReuse() {
#if compiler(>=5.5.2) && canImport(_Concurrency)
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
XCTAsyncTest {
let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() }
defer { XCTAssertNoThrow(try bin.shutdown()) }
let client = makeDefaultHTTPClient()
defer { XCTAssertNoThrow(try client.syncShutdown()) }
let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:))
var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/")
request.method = .POST
request.body = .stream([
ByteBuffer(string: "1"),
ByteBuffer(string: "2"),
ByteBuffer(string: "34"),
].asAsyncSequence(), length: .unknown)

guard let response1 = await XCTAssertNoThrowWithResult(
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
) else { return }
XCTAssertEqual(response1.headers["content-length"], [])
guard let body = await XCTAssertNoThrowWithResult(
try await response1.body.collect()
) else { return }
XCTAssertEqual(body, ByteBuffer(string: "1234"))

guard let response2 = await XCTAssertNoThrowWithResult(
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
) else { return }
XCTAssertEqual(response2.headers["content-length"], [])
guard let body = await XCTAssertNoThrowWithResult(
try await response2.body.collect()
) else { return }
XCTAssertEqual(body, ByteBuffer(string: "1234"))
}
#endif
}
}

#if compiler(>=5.5.2) && canImport(_Concurrency)
Expand Down
4 changes: 2 additions & 2 deletions Tests/AsyncHTTPClientTests/HTTPClientRequestTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -489,11 +489,11 @@ private struct LengthMismatch: Error {
}

@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension Optional where Wrapped == HTTPClientRequest.Body {
extension Optional where Wrapped == HTTPClientRequest.Prepared.Body {
/// Accumulates all data from `self` into a single `ByteBuffer` and checks that the user specified length matches
/// the length of the accumulated data.
fileprivate func read() async throws -> ByteBuffer {
switch self?.mode {
switch self {
case .none:
return ByteBuffer()
case .byteBuffer(let buffer):
Expand Down