Skip to content

Commit 68d200e

Browse files
committed
Allow HTTPClientRequest to be executed multiple times if body is an AsyncSequence
1 parent d764c1a commit 68d200e

File tree

6 files changed

+107
-18
lines changed

6 files changed

+107
-18
lines changed

Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,25 @@
1414

1515
#if compiler(>=5.5.2) && canImport(_Concurrency)
1616
import struct Foundation.URL
17+
import NIOCore
1718
import NIOHTTP1
1819

1920
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
2021
extension HTTPClientRequest {
2122
struct Prepared {
23+
enum Body {
24+
case asyncSequence(
25+
length: RequestBodyLength,
26+
nextBodyPart: (ByteBufferAllocator) async throws -> ByteBuffer?
27+
)
28+
case sequence(
29+
length: RequestBodyLength,
30+
canBeConsumedMultipleTimes: Bool,
31+
makeCompleteBody: (ByteBufferAllocator) -> ByteBuffer
32+
)
33+
case byteBuffer(ByteBuffer)
34+
}
35+
2236
var url: URL
2337
var poolKey: ConnectionPool.Key
2438
var requestFramingMetadata: RequestFramingMetadata
@@ -53,11 +67,25 @@ extension HTTPClientRequest.Prepared {
5367
uri: deconstructedURL.uri,
5468
headers: headers
5569
),
56-
body: request.body
70+
body: request.body.map { .init($0) }
5771
)
5872
}
5973
}
6074

75+
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
76+
extension HTTPClientRequest.Prepared.Body {
77+
init(_ body: HTTPClientRequest.Body) {
78+
switch body.mode {
79+
case .asyncSequence(let length, let makeAsyncIterator):
80+
self = .asyncSequence(length: length, nextBodyPart: makeAsyncIterator())
81+
case .sequence(let length, let canBeConsumedMultipleTimes, let makeCompleteBody):
82+
self = .sequence(length: length, canBeConsumedMultipleTimes: canBeConsumedMultipleTimes, makeCompleteBody: makeCompleteBody)
83+
case .byteBuffer(let byteBuffer):
84+
self = .byteBuffer(byteBuffer)
85+
}
86+
}
87+
}
88+
6189
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
6290
extension RequestBodyLength {
6391
init(_ body: HTTPClientRequest.Body?) {

Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,26 @@ extension HTTPClientRequest {
5050
public struct Body {
5151
@usableFromInline
5252
internal enum Mode {
53-
case asyncSequence(length: RequestBodyLength, (ByteBufferAllocator) async throws -> ByteBuffer?)
54-
case sequence(length: RequestBodyLength, canBeConsumedMultipleTimes: Bool, (ByteBufferAllocator) -> ByteBuffer)
53+
/// - parameters:
54+
/// - length: complete body length.
55+
/// If `length` is `.known`, `nextBodyPart` is not allowed to produce more bytes than `length` defines.
56+
/// - makeAsyncIterator: Creates a new async iterator under the hood and returns a function which will call `next()` on it.
57+
/// The returned function then produce the next body buffer asynchronously.
58+
/// We use a closure as abstraction instead of an existential to enable specialization.
59+
case asyncSequence(
60+
length: RequestBodyLength,
61+
makeAsyncIterator: () -> ((ByteBufferAllocator) async throws -> ByteBuffer?)
62+
)
63+
/// - parameters:
64+
/// - length: complete body length.
65+
/// If `length` is `.known`, `nextBodyPart` is not allowed to produce more bytes than `length` defines.
66+
/// - canBeConsumedMultipleTimes: if `makeBody` can be called multiple times and returns the same result.
67+
/// - makeCompleteBody: function to produce the complete body.
68+
case sequence(
69+
length: RequestBodyLength,
70+
canBeConsumedMultipleTimes: Bool,
71+
makeCompleteBody: (ByteBufferAllocator) -> ByteBuffer
72+
)
5573
case byteBuffer(ByteBuffer)
5674
}
5775

@@ -180,9 +198,11 @@ extension HTTPClientRequest.Body {
180198
_ sequenceOfBytes: SequenceOfBytes,
181199
length: Length
182200
) -> Self where SequenceOfBytes.Element == ByteBuffer {
183-
var iterator = sequenceOfBytes.makeAsyncIterator()
184-
let body = self.init(.asyncSequence(length: length.storage) { _ -> ByteBuffer? in
185-
try await iterator.next()
201+
let body = self.init(.asyncSequence(length: length.storage) {
202+
var iterator = sequenceOfBytes.makeAsyncIterator()
203+
return { _ -> ByteBuffer? in
204+
try await iterator.next()
205+
}
186206
})
187207
return body
188208
}
@@ -205,16 +225,18 @@ extension HTTPClientRequest.Body {
205225
_ bytes: Bytes,
206226
length: Length
207227
) -> Self where Bytes.Element == UInt8 {
208-
var iterator = bytes.makeAsyncIterator()
209-
let body = self.init(.asyncSequence(length: length.storage) { allocator -> ByteBuffer? in
210-
var buffer = allocator.buffer(capacity: 1024) // TODO: Magic number
211-
while buffer.writableBytes > 0, let byte = try await iterator.next() {
212-
buffer.writeInteger(byte)
213-
}
214-
if buffer.readableBytes > 0 {
215-
return buffer
228+
let body = self.init(.asyncSequence(length: length.storage) {
229+
var iterator = bytes.makeAsyncIterator()
230+
return { allocator -> ByteBuffer? in
231+
var buffer = allocator.buffer(capacity: 1024) // TODO: Magic number
232+
while buffer.writableBytes > 0, let byte = try await iterator.next() {
233+
buffer.writeInteger(byte)
234+
}
235+
if buffer.readableBytes > 0 {
236+
return buffer
237+
}
238+
return nil
216239
}
217-
return nil
218240
})
219241
return body
220242
}

Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ extension Transaction: HTTPExecutableRequest {
194194
break
195195

196196
case .startStream(let allocator):
197-
switch self.request.body?.mode {
197+
switch self.request.body {
198198
case .asyncSequence(_, let next):
199199
// it is safe to call this async here. it dispatches...
200200
self.continueRequestBodyStream(allocator, next: next)

Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests+XCTest.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ extension AsyncAwaitEndToEndTests {
4444
("testRedirectChangesHostHeader", testRedirectChangesHostHeader),
4545
("testShutdown", testShutdown),
4646
("testCancelingBodyDoesNotCrash", testCancelingBodyDoesNotCrash),
47+
("testAsyncSequenceReuse", testAsyncSequenceReuse),
4748
]
4849
}
4950
}

Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,44 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
598598
}
599599
#endif
600600
}
601+
602+
func testAsyncSequenceReuse() {
603+
#if compiler(>=5.5.2) && canImport(_Concurrency)
604+
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
605+
XCTAsyncTest {
606+
let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() }
607+
defer { XCTAssertNoThrow(try bin.shutdown()) }
608+
let client = makeDefaultHTTPClient()
609+
defer { XCTAssertNoThrow(try client.syncShutdown()) }
610+
let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:))
611+
var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/")
612+
request.method = .POST
613+
request.body = .stream([
614+
ByteBuffer(string: "1"),
615+
ByteBuffer(string: "2"),
616+
ByteBuffer(string: "34"),
617+
].asAsyncSequence(), length: .unknown)
618+
619+
guard let response1 = await XCTAssertNoThrowWithResult(
620+
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
621+
) else { return }
622+
XCTAssertEqual(response1.headers["content-length"], [])
623+
guard let body = await XCTAssertNoThrowWithResult(
624+
try await response1.body.collect()
625+
) else { return }
626+
XCTAssertEqual(body, ByteBuffer(string: "1234"))
627+
628+
guard let response2 = await XCTAssertNoThrowWithResult(
629+
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
630+
) else { return }
631+
XCTAssertEqual(response2.headers["content-length"], [])
632+
guard let body = await XCTAssertNoThrowWithResult(
633+
try await response2.body.collect()
634+
) else { return }
635+
XCTAssertEqual(body, ByteBuffer(string: "1234"))
636+
}
637+
#endif
638+
}
601639
}
602640

603641
#if compiler(>=5.5.2) && canImport(_Concurrency)

Tests/AsyncHTTPClientTests/HTTPClientRequestTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -489,11 +489,11 @@ private struct LengthMismatch: Error {
489489
}
490490

491491
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
492-
extension Optional where Wrapped == HTTPClientRequest.Body {
492+
extension Optional where Wrapped == HTTPClientRequest.Prepared.Body {
493493
/// Accumulates all data from `self` into a single `ByteBuffer` and checks that the user specified length matches
494494
/// the length of the accumulated data.
495495
fileprivate func read() async throws -> ByteBuffer {
496-
switch self?.mode {
496+
switch self {
497497
case .none:
498498
return ByteBuffer()
499499
case .byteBuffer(let buffer):

0 commit comments

Comments
 (0)