-
Notifications
You must be signed in to change notification settings - Fork 125
Collect function fix #672
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Collect function fix #672
Changes from 11 commits
668b33a
6276e51
104eff5
2ad4a6b
e919181
df627a5
8156844
9b63815
e8d8f85
0de9e27
045a063
9dc6f33
ecbcd90
ae17940
483a3f2
14cfdd4
e60d65b
2fbef23
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -32,16 +32,30 @@ public struct HTTPClientResponse: Sendable { | |||||||
/// The body of this HTTP response. | ||||||||
public var body: Body | ||||||||
|
||||||||
|
||||||||
@inlinable | ||||||||
init( | ||||||||
bag: Transaction, | ||||||||
version: HTTPVersion, | ||||||||
status: HTTPResponseStatus, | ||||||||
headers: HTTPHeaders | ||||||||
version: HTTPVersion = .http1_1, | ||||||||
status: HTTPResponseStatus = .ok, | ||||||||
headers: HTTPHeaders = [:], | ||||||||
body: Body = Body(), | ||||||||
requestMethod: HTTPMethod? | ||||||||
) { | ||||||||
self.version = version | ||||||||
self.status = status | ||||||||
self.headers = headers | ||||||||
self.body = Body(TransactionBody(bag)) | ||||||||
self.body = body | ||||||||
} | ||||||||
|
||||||||
init( | ||||||||
bag: Transaction, | ||||||||
version: HTTPVersion, | ||||||||
status: HTTPResponseStatus, | ||||||||
headers: HTTPHeaders, | ||||||||
requestMethod: HTTPMethod | ||||||||
) { | ||||||||
let contentLength = HTTPClientResponse.expectedContentLength(requestMethod: requestMethod, headers: headers, status: status) | ||||||||
self.init(version: version, status: status, headers: headers, body: .init(TransactionBody(bag, expextedContentLength: contentLength)), requestMethod: requestMethod) | ||||||||
} | ||||||||
|
||||||||
@inlinable public init( | ||||||||
|
@@ -50,10 +64,7 @@ public struct HTTPClientResponse: Sendable { | |||||||
headers: HTTPHeaders = [:], | ||||||||
body: Body = Body() | ||||||||
) { | ||||||||
self.version = version | ||||||||
self.status = status | ||||||||
self.headers = headers | ||||||||
self.body = body | ||||||||
self.init(version: version, status: status, headers: headers, body: body, requestMethod: nil) | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
|
@@ -83,6 +94,56 @@ extension HTTPClientResponse { | |||||||
@inlinable public func makeAsyncIterator() -> AsyncIterator { | ||||||||
.init(storage: self.storage.makeAsyncIterator()) | ||||||||
} | ||||||||
|
||||||||
@inlinable init(storage: Storage) { | ||||||||
self.storage = storage | ||||||||
} | ||||||||
|
||||||||
/// Accumulates `Body` of ``ByteBuffer``s into a single ``ByteBuffer``. | ||||||||
/// - Parameters: | ||||||||
/// - maxBytes: The maximum number of bytes this method is allowed to accumulate | ||||||||
/// - Throws: `NIOTooManyBytesError` if the the sequence contains more than `maxBytes`. | ||||||||
/// - Returns: the number of bytes collected over time | ||||||||
@inlinable public func collect(upTo maxBytes: Int) async throws -> ByteBuffer { | ||||||||
switch storage { | ||||||||
case .transaction(let transactionBody): | ||||||||
if let contentLength = transactionBody.expectedContentLength { | ||||||||
if contentLength > maxBytes { | ||||||||
throw NIOTooManyBytesError() | ||||||||
} | ||||||||
} | ||||||||
case .anyAsyncSequence: | ||||||||
break | ||||||||
} | ||||||||
|
||||||||
/// <#Description#> | ||||||||
/// - Parameters: | ||||||||
/// - body: <#body description#> | ||||||||
/// - maxBytes: <#maxBytes description#> | ||||||||
/// - Throws: <#description#> | ||||||||
/// - Returns: <#description#> | ||||||||
func collect<Body: AsyncSequence>(_ body: Body, maxBytes: Int) async throws -> ByteBuffer where Body.Element == ByteBuffer { | ||||||||
dnadoba marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
try await body.collect(upTo: maxBytes) | ||||||||
} | ||||||||
return try await collect(self, maxBytes: maxBytes) | ||||||||
|
||||||||
} | ||||||||
|
||||||||
} | ||||||||
} | ||||||||
|
||||||||
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) | ||||||||
extension HTTPClientResponse { | ||||||||
static func expectedContentLength(requestMethod: HTTPMethod, headers: HTTPHeaders, status: HTTPResponseStatus) -> Int? { | ||||||||
if status == .notModified { | ||||||||
return 0 | ||||||||
} else if requestMethod == .HEAD { | ||||||||
return 0 | ||||||||
} | ||||||||
else { | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
let contentLength = headers["content-length"].first.flatMap({Int($0, radix: 10)}) | ||||||||
return contentLength | ||||||||
} | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
|
@@ -132,10 +193,10 @@ extension HTTPClientResponse.Body.Storage.AsyncIterator: AsyncIteratorProtocol { | |||||||
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) | ||||||||
extension HTTPClientResponse.Body { | ||||||||
init(_ body: TransactionBody) { | ||||||||
self.init(.transaction(body)) | ||||||||
self.init(storage: .transaction(body)) | ||||||||
} | ||||||||
|
||||||||
@usableFromInline init(_ storage: Storage) { | ||||||||
@inlinable init(_ storage: Storage) { | ||||||||
self.storage = storage | ||||||||
} | ||||||||
|
||||||||
|
@@ -146,7 +207,7 @@ extension HTTPClientResponse.Body { | |||||||
@inlinable public static func stream<SequenceOfBytes>( | ||||||||
_ sequenceOfBytes: SequenceOfBytes | ||||||||
) -> Self where SequenceOfBytes: AsyncSequence & Sendable, SequenceOfBytes.Element == ByteBuffer { | ||||||||
self.init(.anyAsyncSequence(AnyAsyncSequence(sequenceOfBytes.singleIteratorPrecondition))) | ||||||||
Self.init(storage: .anyAsyncSequence(AnyAsyncSequence(sequenceOfBytes.singleIteratorPrecondition))) | ||||||||
} | ||||||||
|
||||||||
public static func bytes(_ byteBuffer: ByteBuffer) -> Self { | ||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -114,7 +114,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { | |
) else { return } | ||
XCTAssertEqual(response.headers["content-length"], ["4"]) | ||
guard let body = await XCTAssertNoThrowWithResult( | ||
try await response.body.collect() | ||
try await response.body.collect(upTo: 1024) | ||
) else { return } | ||
XCTAssertEqual(body, ByteBuffer(string: "1234")) | ||
} | ||
|
@@ -137,7 +137,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { | |
) else { return } | ||
XCTAssertEqual(response.headers["content-length"], []) | ||
guard let body = await XCTAssertNoThrowWithResult( | ||
try await response.body.collect() | ||
try await response.body.collect(upTo: 1024) | ||
) else { return } | ||
XCTAssertEqual(body, ByteBuffer(string: "1234")) | ||
} | ||
|
@@ -160,7 +160,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { | |
) else { return } | ||
XCTAssertEqual(response.headers["content-length"], []) | ||
guard let body = await XCTAssertNoThrowWithResult( | ||
try await response.body.collect() | ||
try await response.body.collect(upTo: 1024) | ||
) else { return } | ||
XCTAssertEqual(body, ByteBuffer(string: "1234")) | ||
} | ||
|
@@ -183,7 +183,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { | |
) else { return } | ||
XCTAssertEqual(response.headers["content-length"], ["4"]) | ||
guard let body = await XCTAssertNoThrowWithResult( | ||
try await response.body.collect() | ||
try await response.body.collect(upTo: 1024) | ||
) else { return } | ||
XCTAssertEqual(body, ByteBuffer(string: "1234")) | ||
} | ||
|
@@ -210,7 +210,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { | |
) else { return } | ||
XCTAssertEqual(response.headers["content-length"], []) | ||
guard let body = await XCTAssertNoThrowWithResult( | ||
try await response.body.collect() | ||
try await response.body.collect(upTo: 1024) | ||
) else { return } | ||
XCTAssertEqual(body, ByteBuffer(string: "1234")) | ||
} | ||
|
@@ -233,7 +233,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { | |
) else { return } | ||
XCTAssertEqual(response.headers["content-length"], []) | ||
guard let body = await XCTAssertNoThrowWithResult( | ||
try await response.body.collect() | ||
try await response.body.collect(upTo: 1024) | ||
) else { return } | ||
XCTAssertEqual(body, ByteBuffer(string: "1234")) | ||
} | ||
|
@@ -522,7 +522,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { | |
) else { | ||
return | ||
} | ||
guard let body = await XCTAssertNoThrowWithResult(try await response.body.collect()) else { return } | ||
guard let body = await XCTAssertNoThrowWithResult(try await response.body.collect(upTo: 1024)) else { return } | ||
var maybeRequestInfo: RequestInfo? | ||
XCTAssertNoThrow(maybeRequestInfo = try JSONDecoder().decode(RequestInfo.self, from: body)) | ||
guard let requestInfo = maybeRequestInfo else { return } | ||
|
@@ -583,7 +583,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { | |
) else { return } | ||
XCTAssertEqual(response1.headers["content-length"], []) | ||
guard let body = await XCTAssertNoThrowWithResult( | ||
try await response1.body.collect() | ||
try await response1.body.collect(upTo: 1024) | ||
) else { return } | ||
XCTAssertEqual(body, ByteBuffer(string: "1234")) | ||
|
||
|
@@ -592,12 +592,11 @@ final class AsyncAwaitEndToEndTests: XCTestCase { | |
) else { return } | ||
XCTAssertEqual(response2.headers["content-length"], []) | ||
guard let body = await XCTAssertNoThrowWithResult( | ||
try await response2.body.collect() | ||
try await response2.body.collect(upTo: 1024) | ||
) else { return } | ||
XCTAssertEqual(body, ByteBuffer(string: "1234")) | ||
} | ||
} | ||
|
||
func testRejectsInvalidCharactersInHeaderFieldNames_http1() { | ||
carolinacass marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._rejectsInvalidCharactersInHeaderFieldNames(mode: .http1_1(ssl: true)) | ||
} | ||
|
@@ -745,17 +744,32 @@ final class AsyncAwaitEndToEndTests: XCTestCase { | |
XCTAssertEqual(response.version, .http2) | ||
} | ||
} | ||
} | ||
|
||
extension AsyncSequence where Element == ByteBuffer { | ||
func collect() async rethrows -> ByteBuffer { | ||
try await self.reduce(into: ByteBuffer()) { accumulatingBuffer, nextBuffer in | ||
var nextBuffer = nextBuffer | ||
accumulatingBuffer.writeBuffer(&nextBuffer) | ||
|
||
func testSimpleContentLengthError() { | ||
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } | ||
XCTAsyncTest { | ||
let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } | ||
defer { XCTAssertNoThrow(try bin.shutdown()) } | ||
let client = makeDefaultHTTPClient() | ||
defer { XCTAssertNoThrow(try client.syncShutdown()) } | ||
let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) | ||
var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/") | ||
request.method = .GET | ||
request.body = .bytes(ByteBuffer(string: "1234")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hrm, I think this muddies the waters of the test somewhat. I don't think this test would fail if you took away the content-length checks, because the actual What you want is for the HTTPbin server to send you back a content-length that's too long and no body at all. That should still throw this error. |
||
|
||
guard var response = await XCTAssertNoThrowWithResult( | ||
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) | ||
) else { return } | ||
await XCTAssertThrowsError( | ||
try await response.body.collect(upTo: 3) | ||
) { | ||
XCTAssertEqualTypeAndValue($0, NIOTooManyBytesError()) | ||
} | ||
} | ||
} | ||
} | ||
|
||
|
||
struct AnySendableSequence<Element>: @unchecked Sendable { | ||
private let wrapped: AnySequence<Element> | ||
init<WrappedSequence: Sequence & Sendable>( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This source file is part of the AsyncHTTPClient open source project | ||
// | ||
// Copyright (c) 2023 Apple Inc. and the AsyncHTTPClient project authors | ||
// Licensed under Apache License v2.0 | ||
// | ||
// See LICENSE.txt for license information | ||
// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
|
||
|
||
@testable import AsyncHTTPClient | ||
import Logging | ||
import NIOCore | ||
import XCTest | ||
|
||
|
||
private func makeDefaultHTTPClient( | ||
eventLoopGroupProvider: HTTPClient.EventLoopGroupProvider = .createNew | ||
) -> HTTPClient { | ||
var config = HTTPClient.Configuration() | ||
config.tlsConfiguration = .clientDefault | ||
config.tlsConfiguration?.certificateVerification = .none | ||
config.httpVersion = .automatic | ||
return HTTPClient( | ||
eventLoopGroupProvider: eventLoopGroupProvider, | ||
configuration: config, | ||
backgroundActivityLogger: Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) | ||
) | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After we have removed the last test case, this is no longer used There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Still here |
||
|
||
final class HTTPClientResponseTests: XCTestCase { | ||
|
||
func testSimpleResponse() { | ||
let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: ["content-length": "1025"], status: .ok) | ||
XCTAssertEqual(response, 1025) | ||
} | ||
|
||
func testSimpleResponseNotModified() { | ||
let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: ["content-length": "1025"], status: .notModified) | ||
XCTAssertEqual(response, 0) | ||
} | ||
|
||
func testSimpleResponseHeadRequestMethod() { | ||
let response = HTTPClientResponse.expectedContentLength(requestMethod: .HEAD, headers: ["content-length": "1025"], status: .ok) | ||
XCTAssertEqual(response, 0) | ||
} | ||
dnadoba marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
func testReponseInitWithStatus() { | ||
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } | ||
XCTAsyncTest { | ||
var response = HTTPClientResponse(status: .notModified , requestMethod: .GET) | ||
response.headers.replaceOrAdd(name: "content-length", value: "1025") | ||
guard let body = await XCTAssertNoThrowWithResult( | ||
try await response.body.collect(upTo: 1024) | ||
) else { return } | ||
XCTAssertEqual(0, body.readableBytes) | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is no longer testing what we want. I think we can just remove it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. still here |
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can remove the
requestMethod
argument name again