diff --git a/Sources/AsyncHTTPClient/AsyncAwait/AsyncRequestBag+StateMachine.swift b/Sources/AsyncHTTPClient/AsyncAwait/AsyncRequestBag+StateMachine.swift new file mode 100644 index 000000000..120ec9045 --- /dev/null +++ b/Sources/AsyncHTTPClient/AsyncAwait/AsyncRequestBag+StateMachine.swift @@ -0,0 +1,570 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if compiler(>=5.5) && canImport(_Concurrency) +import Logging +import NIOCore +import NIOHTTP1 + +@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) +extension AsyncRequestBag { + struct StateMachine { + struct ExecutionContext { + let executor: HTTPRequestExecutor + let allocator: ByteBufferAllocator + let continuation: UnsafeContinuation + } + + private enum State { + case initialized + case waiting(UnsafeContinuation) + case queued(UnsafeContinuation, HTTPRequestScheduler) + case executing(ExecutionContext, RequestStreamState, ResponseStreamState) + case finished(error: Error?, HTTPClientResponse.Body.IteratorStream.ID?) + } + + fileprivate enum RequestStreamState { + case initialized + case producing + case paused + case finished + } + + fileprivate enum ResponseStreamState { + enum Next { + case askExecutorForMore + case error(Error) + case eof + } + + case initialized + case waitingForStream(CircularBuffer, next: Next) + case buffering(HTTPClientResponse.Body.IteratorStream.ID, CircularBuffer, next: Next) + case waitingForRemote(HTTPClientResponse.Body.IteratorStream.ID, UnsafeContinuation) + case finished(HTTPClientResponse.Body.IteratorStream.ID, UnsafeContinuation) + } + + private var state: State + + init() { + self.state = .initialized + } + + mutating func registerContinuation(_ continuation: UnsafeContinuation) { + guard case .initialized = self.state else { + preconditionFailure("Invalid state: \(self.state)") + } + + self.state = .waiting(continuation) + } + + mutating func requestWasQueued(_ scheduler: HTTPRequestScheduler) { + guard case .waiting(let continuation) = self.state else { + // There might be a race between `requestWasQueued` and `willExecuteRequest`: + // + // If the request is created and passed to the HTTPClient on thread A, it will move into + // the connection pool lock in thread A. If no connection is available, thread A will + // add the request to the waiters and leave the connection pool lock. + // `requestWasQueued` will be called outside the connection pool lock on thread A. + // However if thread B has a connection that becomes available and thread B enters the + // connection pool lock directly after thread A, the request will be immediately + // scheduled for execution on thread B. After the thread B has left the lock it will + // call `willExecuteRequest` directly after. + // + // Having an order in the connection pool lock, does not guarantee an order in calling: + // `requestWasQueued` and `willExecuteRequest`. + // + // For this reason we must check the state here... If we are not `.initialized`, we are + // already executing. + return + } + + self.state = .queued(continuation, scheduler) + } + + enum FailAction { + case none + /// fail response before head received. scheduler and executor are exclusive here. + case failResponseHead(UnsafeContinuation, Error, HTTPRequestScheduler?, HTTPRequestExecutor?) + case failResponseStream(UnsafeContinuation, Error, HTTPRequestExecutor) + } + + mutating func fail(_ error: Error) -> FailAction { + switch self.state { + case .initialized: + preconditionFailure("") + + case .waiting(let continuation): + self.state = .finished(error: error, nil) + return .failResponseHead(continuation, error, nil, nil) + + case .queued(let continuation, let scheduler): + self.state = .finished(error: error, nil) + return .failResponseHead(continuation, error, scheduler, nil) + + case .executing(let context, _, .initialized): + self.state = .finished(error: error, nil) + return .failResponseHead(context.continuation, error, nil, context.executor) + + case .executing(_, _, .waitingForStream(_, next: .error)), + .executing(_, _, .buffering(_, _, next: .error)): + return .none + + case .executing(let context, let requestStreamState, .waitingForStream(let buffer, next: .askExecutorForMore)), + .executing(let context, let requestStreamState, .waitingForStream(let buffer, next: .eof)): + switch requestStreamState { + case .initialized: + preconditionFailure("Invalid state") + + case .paused, .finished: + self.state = .executing(context, requestStreamState, .waitingForStream(buffer, next: .error(error))) + return .none + + case .producing: + self.state = .executing(context, .paused, .waitingForStream(buffer, next: .error(error))) + return .none + } + + case .executing(let context, let requestStreamState, .buffering(let streamID, let buffer, next: .askExecutorForMore)), + .executing(let context, let requestStreamState, .buffering(let streamID, let buffer, next: .eof)): + + switch requestStreamState { + case .initialized: + preconditionFailure("Invalid state") + + case .paused, .finished: + self.state = .executing(context, requestStreamState, .buffering(streamID, buffer, next: .error(error))) + return .none + + case .producing: + self.state = .executing(context, .paused, .buffering(streamID, buffer, next: .error(error))) + return .none + } + + case .executing(let context, _, .waitingForRemote(let streamID, let continuation)): + self.state = .finished(error: error, streamID) + return .failResponseStream(continuation, error, context.executor) + + case .finished(error: _, _): + return .none + + case .executing(let context, _, .finished(let streamID, let continuation)): + self.state = .finished(error: error, streamID) + return .failResponseStream(continuation, error, context.executor) + } + } + + // MARK: - Request - + + enum StartExecutionAction { + case cancel(HTTPRequestExecutor) + case none + } + + mutating func willExecuteRequest(_ executor: HTTPRequestExecutor) -> StartExecutionAction { + switch self.state { + case .waiting(let continuation), .queued(let continuation, _): + let context = ExecutionContext( + executor: executor, + allocator: .init(), + continuation: continuation + ) + self.state = .executing(context, .initialized, .initialized) + return .none + case .finished(error: .some, .none): + return .cancel(executor) + case .initialized, + .executing, + .finished(error: .none, _), + .finished(error: .some, .some): + preconditionFailure("Invalid state: \(self.state)") + } + } + + enum ResumeProducingAction { + case resumeStream(ByteBufferAllocator) + case none + } + + mutating func resumeRequestBodyStream() -> ResumeProducingAction { + switch self.state { + case .initialized, .waiting, .queued: + preconditionFailure("A request stream can only be resumed, if the request was started") + + case .executing(let context, .initialized, .initialized): + self.state = .executing(context, .producing, .initialized) + return .resumeStream(context.allocator) + + case .executing(_, .producing, _): + preconditionFailure("Expected that resume is only called when if we were paused before") + + case .executing(let context, .paused, let responseState): + self.state = .executing(context, .producing, responseState) + return .resumeStream(context.allocator) + + case .executing(_, .finished, _): + // the channels writability changed to writable after we have forwarded all the + // request bytes. Can be ignored. + return .none + + case .executing(_, .initialized, .waitingForStream), + .executing(_, .initialized, .buffering), + .executing(_, .initialized, .waitingForRemote), + .executing(_, .initialized, .finished): + preconditionFailure("Invalid states: Response can not be received before request") + + case .finished: + return .none + } + } + + mutating func pauseRequestBodyStream() { + switch self.state { + case .initialized, + .waiting, + .queued, + .executing(_, .initialized, _): + preconditionFailure("A request stream can only be resumed, if the request was started") + + case .executing(let context, .producing, let responseSteam): + self.state = .executing(context, .paused, responseSteam) + + case .executing(_, .paused, _), + .executing(_, .finished, _), + .finished: + // the channels writability changed to writable after we have forwarded all the + // request bytes. Can be ignored. + break + } + } + + enum NextWriteAction { + case write(ByteBuffer, HTTPRequestExecutor, continue: Bool) + case ignore + } + + func producedNextRequestPart(_ part: ByteBuffer) -> NextWriteAction { + switch self.state { + case .initialized, + .waiting, + .queued, + .executing(_, .initialized, _), + .executing(_, .finished, _): + preconditionFailure("A request stream can only be resumed, if the request was started") + + case .executing(let context, .producing, _): + return .write(part, context.executor, continue: true) + + case .executing(let context, .paused, _): + return .write(part, context.executor, continue: false) + + case .finished: + return .ignore + } + } + + enum ProduceErrorAction { + case none + case informRequestAboutFailure(Error, cancelExecutor: HTTPRequestExecutor, failResponseStream: UnsafeContinuation?) + } + + mutating func failedToProduceNextRequestPart(_ error: Error) -> ProduceErrorAction { + switch self.state { + case .initialized, + .waiting, + .queued, + .executing(_, .initialized, _), + .executing(_, .finished, _): + preconditionFailure("A request stream can only be resumed, if the request was started") + + case .executing(let context, .producing, .initialized), + .executing(let context, .producing, .waitingForStream), + .executing(let context, .paused, .initialized), + .executing(let context, .paused, .waitingForStream): + self.state = .finished(error: error, nil) + return .informRequestAboutFailure(error, cancelExecutor: context.executor, failResponseStream: nil) + + case .executing(let context, .producing, .buffering(let streamID, _, next: _)), + .executing(let context, .paused, .buffering(let streamID, _, next: _)): + self.state = .finished(error: error, streamID) + return .informRequestAboutFailure(error, cancelExecutor: context.executor, failResponseStream: nil) + + case .executing(let context, .producing, .waitingForRemote(let streamID, let continuation)), + .executing(let context, .paused, .waitingForRemote(let streamID, let continuation)), + .executing(let context, .producing, .finished(let streamID, let continuation)), + .executing(let context, .paused, .finished(let streamID, let continuation)): + self.state = .finished(error: error, streamID) + return .informRequestAboutFailure(error, cancelExecutor: context.executor, failResponseStream: continuation) + + case .finished: + return .none + } + } + + enum FinishAction { + case forwardStreamFinished(HTTPRequestExecutor) + case none + } + + mutating func finishRequestBodyStream() -> FinishAction { + switch self.state { + case .initialized, + .waiting, + .queued, + .executing(_, .initialized, _), + .executing(_, .finished, _): + preconditionFailure("Invalid state: \(self.state)") + + case .executing(let context, .producing, let responseState), + .executing(let context, .paused, let responseState): + self.state = .executing(context, .finished, responseState) + return .forwardStreamFinished(context.executor) + + case .finished: + return .none + } + } + + // MARK: - Response - + + enum ReceiveResponseHeadAction { + case succeedResponseHead(HTTPResponseHead, UnsafeContinuation) + case none + } + + mutating func receiveResponseHead(_ head: HTTPResponseHead) -> ReceiveResponseHeadAction { + switch self.state { + case .initialized, + .waiting, + .queued, + .executing(_, _, .waitingForStream), + .executing(_, _, .buffering), + .executing(_, _, .waitingForRemote): + preconditionFailure("How can we receive a response, if the request hasn't started yet.") + + case .executing(let context, let requestState, .initialized): + self.state = .executing(context, requestState, .waitingForStream(.init(), next: .askExecutorForMore)) + return .succeedResponseHead(head, context.continuation) + + case .finished(error: .some, _): + return .none + + case .executing(_, _, .finished), + .finished(error: .none, _): + preconditionFailure("How can the request be finished without error, before receiving response head?") + } + } + + enum ReceiveResponsePartAction { + case none + case succeedContinuation(UnsafeContinuation, ByteBuffer) + } + + mutating func receiveResponseBodyParts(_ buffer: CircularBuffer) -> ReceiveResponsePartAction { + switch self.state { + case .initialized, .waiting, .queued: + preconditionFailure("How can we receive a response body part, if the request hasn't started yet.") + case .executing(_, _, .initialized): + preconditionFailure("If we receive a response body, we must have received a head before") + + case .executing(let context, let requestState, .buffering(let streamID, var currentBuffer, next: let next)): + guard case .askExecutorForMore = next else { + preconditionFailure("If we have received an error or eof before, why did we get another body part? Next: \(next)") + } + + if currentBuffer.isEmpty { + currentBuffer = buffer + } else { + currentBuffer.append(contentsOf: buffer) + } + self.state = .executing(context, requestState, .buffering(streamID, currentBuffer, next: next)) + return .none + + case .executing(let executor, let requestState, .waitingForStream(var currentBuffer, next: let next)): + guard case .askExecutorForMore = next else { + preconditionFailure("If we have received an error or eof before, why did we get another body part? Next: \(next)") + } + + if currentBuffer.isEmpty { + currentBuffer = buffer + } else { + currentBuffer.append(contentsOf: buffer) + } + self.state = .executing(executor, requestState, .waitingForStream(currentBuffer, next: next)) + return .none + + case .executing(let executor, let requestState, .waitingForRemote(let streamID, let continuation)): + var buffer = buffer + let first = buffer.removeFirst() + self.state = .executing(executor, requestState, .buffering(streamID, buffer, next: .askExecutorForMore)) + return .succeedContinuation(continuation, first) + + case .finished(error: .some, _): + return .none + case .executing(_, _, .finished), + .finished(error: .none, _): + preconditionFailure("How can the request be finished without error, before receiving response head?") + } + } + + enum ConsumeAction { + case succeedContinuation(UnsafeContinuation, ByteBuffer?) + case failContinuation(UnsafeContinuation, Error) + case askExecutorForMore(HTTPRequestExecutor) + } + + struct TriedToRegisteredASecondConsumer: Error {} + + mutating func consumeNextResponsePart( + streamID: HTTPClientResponse.Body.IteratorStream.ID, + continuation: UnsafeContinuation + ) -> ConsumeAction { + switch self.state { + case .initialized, + .waiting, + .queued, + .executing(_, _, .initialized): + preconditionFailure("If we receive a response body, we must have received a head before") + + case .executing(_, _, .finished(_, _)): + preconditionFailure("This is an invalid state at this point. We are waiting for the request stream to finish to succeed or response stream.") + + case .executing(let context, let requestState, .waitingForStream(var buffer, next: .askExecutorForMore)): + if buffer.isEmpty { + self.state = .executing(context, requestState, .waitingForRemote(streamID, continuation)) + return .askExecutorForMore(context.executor) + } else { + let toReturn = buffer.removeFirst() + self.state = .executing(context, requestState, .buffering(streamID, buffer, next: .askExecutorForMore)) + return .succeedContinuation(continuation, toReturn) + } + + case .executing(_, _, .waitingForStream(_, next: .error(let error))): + self.state = .finished(error: error, streamID) + return .failContinuation(continuation, error) + + case .executing(_, _, .waitingForStream(let buffer, next: .eof)) where buffer.isEmpty: + self.state = .finished(error: nil, streamID) + return .succeedContinuation(continuation, nil) + + case .executing(let context, let requestState, .waitingForStream(var buffer, next: .eof)): + assert(!buffer.isEmpty) + let toReturn = buffer.removeFirst() + self.state = .executing(context, requestState, .buffering(streamID, buffer, next: .eof)) + return .succeedContinuation(continuation, toReturn) + + case .executing(let context, let requestState, .buffering(let streamID, var buffer, next: .askExecutorForMore)): + if buffer.isEmpty { + self.state = .executing(context, requestState, .waitingForRemote(streamID, continuation)) + return .askExecutorForMore(context.executor) + } else { + let toReturn = buffer.removeFirst() + self.state = .executing(context, requestState, .buffering(streamID, buffer, next: .askExecutorForMore)) + return .succeedContinuation(continuation, toReturn) + } + + case .executing(_, _, .buffering(let registeredStreamID, _, next: .error(let error))): + guard registeredStreamID == streamID else { + return .failContinuation(continuation, TriedToRegisteredASecondConsumer()) + } + self.state = .finished(error: error, registeredStreamID) + return .failContinuation(continuation, error) + + case .executing(_, _, .buffering(let registeredStreamID, let buffer, next: .eof)) where buffer.isEmpty: + guard registeredStreamID == streamID else { + return .failContinuation(continuation, TriedToRegisteredASecondConsumer()) + } + self.state = .finished(error: nil, registeredStreamID) + return .succeedContinuation(continuation, nil) + + case .executing(let context, let requestState, .buffering(let streamID, var buffer, next: .eof)): + assert(!buffer.isEmpty) + let toReturn = buffer.removeFirst() + self.state = .executing(context, requestState, .buffering(streamID, buffer, next: .eof)) + return .succeedContinuation(continuation, toReturn) + + case .executing(_, _, .waitingForRemote(let registeredStreamID, let continuation)): + if registeredStreamID != streamID { + return .failContinuation(continuation, TriedToRegisteredASecondConsumer()) + } + preconditionFailure("") + + case .finished(error: .some(let error), let registeredStreamID): + guard registeredStreamID == streamID else { + return .failContinuation(continuation, TriedToRegisteredASecondConsumer()) + } + return .failContinuation(continuation, error) + case .finished(error: .none, let registeredStreamID): + guard registeredStreamID == streamID else { + return .failContinuation(continuation, TriedToRegisteredASecondConsumer()) + } + return .succeedContinuation(continuation, nil) + } + } + + enum ReceiveResponseEndAction { + case succeedContinuation(UnsafeContinuation, ByteBuffer) + case succeedRequest(UnsafeContinuation) + case none + } + + mutating func succeedRequest(_ newChunks: CircularBuffer?) -> ReceiveResponseEndAction { + switch self.state { + case .initialized, .waiting, .queued: + preconditionFailure("How can we receive a response body part, if the request hasn't started yet.") + + case .executing(_, _, .initialized): + preconditionFailure("If we receive a response end, we must have received a head before") + + case .executing(let context, let requestState, .waitingForStream(var buffer, next: .askExecutorForMore)): + if let newChunks = newChunks, !newChunks.isEmpty { + buffer.append(contentsOf: newChunks) + } + self.state = .executing(context, requestState, .waitingForStream(buffer, next: .eof)) + return .none + + case .executing(let context, let requestState, .waitingForRemote(let streamID, let continuation)): + if var newChunks = newChunks, !newChunks.isEmpty { + let first = newChunks.removeFirst() + self.state = .executing(context, requestState, .buffering(streamID, newChunks, next: .eof)) + return .succeedContinuation(continuation, first) + } + + self.state = .finished(error: nil, streamID) + return .succeedRequest(continuation) + + case .executing(let context, let requestState, .buffering(let streamID, var buffer, next: .askExecutorForMore)): + if let newChunks = newChunks, !newChunks.isEmpty { + buffer.append(contentsOf: newChunks) + } + self.state = .executing(context, requestState, .buffering(streamID, buffer, next: .eof)) + return .none + + case .finished(error: .some, _): + return .none + + case .finished(error: .none, _): + preconditionFailure("How can the request be finished without error, before receiving response head?") + + case .executing(_, _, .waitingForStream(_, next: .error)), + .executing(_, _, .waitingForStream(_, next: .eof)), + .executing(_, _, .buffering(_, _, next: .error)), + .executing(_, _, .buffering(_, _, next: .eof)), + .executing(_, _, .finished(_, _)): + preconditionFailure("How can the request be succeeded, if we received an error or eof before") + } + } + } +} + +#endif diff --git a/Sources/AsyncHTTPClient/AsyncAwait/AsyncRequestBag.swift b/Sources/AsyncHTTPClient/AsyncAwait/AsyncRequestBag.swift new file mode 100644 index 000000000..f68c031fc --- /dev/null +++ b/Sources/AsyncHTTPClient/AsyncAwait/AsyncRequestBag.swift @@ -0,0 +1,324 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if compiler(>=5.5) && canImport(_Concurrency) +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOHTTP1 +import NIOSSL + +@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) +class AsyncRequestBag { + let logger: Logger + // TODO: store `PreparedRequest` as a single property + let request: HTTPClientRequest.Prepared + var requestHead: HTTPRequestHead { self.request.head } + var requestBody: HTTPClientRequest.Body? { self.request.body } + var poolKey: ConnectionPool.Key { self.request.poolKey } + var requestFramingMetadata: RequestFramingMetadata { self.request.requestFramingMetadata } + + let connectionDeadline: NIODeadline + let preferredEventLoop: EventLoop + let requestOptions: RequestOptions + + private let stateLock = Lock() + private var state: StateMachine = .init() + + init( + request: HTTPClientRequest.Prepared, + requestOptions: RequestOptions, + logger: Logger, + connectionDeadline: NIODeadline, + preferredEventLoop: EventLoop, + responseContinuation: UnsafeContinuation + ) { + self.request = request + self.requestOptions = requestOptions + self.logger = logger + self.connectionDeadline = connectionDeadline + self.preferredEventLoop = preferredEventLoop + + self.state.registerContinuation(responseContinuation) + } + + // MARK: Scheduled request + + func cancel() { + self.fail(HTTPClientError.cancelled) + } + + func requestWasQueued(_ scheduler: HTTPRequestScheduler) { + self.stateLock.withLock { + self.state.requestWasQueued(scheduler) + } + } + + func fail(_ error: Error) { + let action = self.stateLock.withLock { + self.state.fail(error) + } + + switch action { + case .none: + break + + case .failResponseHead(let continuation, let error, let scheduler, let executor): + continuation.resume(throwing: error) + scheduler?.cancelRequest(self) // NOTE: scheduler and executor are exclusive here + executor?.cancelRequest(self) + + case .failResponseStream(let continuation, let error, let executor): + continuation.resume(throwing: error) + executor.cancelRequest(self) + } + } + + // MARK: Scheduled request + + func willExecuteRequest(_ executor: HTTPRequestExecutor) { + let action = self.stateLock.withLock { + self.state.willExecuteRequest(executor) + } + + switch action { + case .cancel(let executor): + executor.cancelRequest(self) + case .none: + break + } + } + + func resumeRequestBodyStream() { + let action = self.stateLock.withLock { + self.state.resumeRequestBodyStream() + } + + switch action { + case .none: + break + case .resumeStream(let allocator): + switch self.requestBody?.mode { + case .asyncSequence(_, let next): + // it is safe to call this async here. it dispatches... + self.continueRequestBodyStream(allocator, next: next) + + case .byteBuffer(let byteBuffer): + self.writeOnceAndOneTimeOnly(byteBuffer: byteBuffer) + + case .none: + break + + case .sequence(_, let create): + let byteBuffer = create(allocator) + self.writeOnceAndOneTimeOnly(byteBuffer: byteBuffer) + } + } + } + + private func writeOnceAndOneTimeOnly(byteBuffer: ByteBuffer) { + // TODO: @fabianfett + let writeAction = self.stateLock.withLock { + self.state.producedNextRequestPart(byteBuffer) + } + guard case .write(let part, let executor, true) = writeAction else { + preconditionFailure("") + } + executor.writeRequestBodyPart(.byteBuffer(part), request: self) + + let finishAction = self.stateLock.withLock { + self.state.finishRequestBodyStream() + } + + guard case .forwardStreamFinished(let executor) = finishAction else { + preconditionFailure("") + } + executor.finishRequestBodyStream(self) + } + + enum AfterNextBodyPartAction { + case `continue` + case pause + } + + private func requestBodyStreamNextPart(_ part: ByteBuffer) -> AfterNextBodyPartAction { + let writeAction = self.stateLock.withLock { + self.state.producedNextRequestPart(part) + } + + switch writeAction { + case .write(let part, let executor, let continueAfter): + executor.writeRequestBodyPart(.byteBuffer(part), request: self) + if continueAfter { + return .continue + } else { + return .pause + } + + case .ignore: + // we only ignore reads, if the request has failed anyway. we should leave + // the reader loop + return .pause + } + } + + private func requestBodyStreamFinished() { + let finishAction = self.stateLock.withLock { + self.state.finishRequestBodyStream() + } + // no more data to produce + switch finishAction { + case .none: + break + case .forwardStreamFinished(let executor): + executor.finishRequestBodyStream(self) + } + return + } + + private func requestBodyStreamFailed(_ error: Error) { + let failAction = self.stateLock.withLock { + self.state.failedToProduceNextRequestPart(error) + } + + switch failAction { + case .none: + break + case .informRequestAboutFailure(let error, cancelExecutor: let executor, let continuation): + executor.cancelRequest(self) + self.fail(error) + continuation?.resume(throwing: error) + } + } + + func pauseRequestBodyStream() { + self.stateLock.withLock { + self.state.pauseRequestBodyStream() + } + } + + func receiveResponseHead(_ head: HTTPResponseHead) { + let action = self.stateLock.withLock { + self.state.receiveResponseHead(head) + } + switch action { + case .none: + break + case .succeedResponseHead(let head, let continuation): + let asyncResponse = HTTPClientResponse( + bag: self, + version: head.version, + status: head.status, + headers: head.headers + ) + continuation.resume(returning: asyncResponse) + } + } + + func receiveResponseBodyParts(_ buffer: CircularBuffer) { + let action = self.stateLock.withLock { + self.state.receiveResponseBodyParts(buffer) + } + switch action { + case .none: + break + case .succeedContinuation(let continuation, let bytes): + continuation.resume(returning: bytes) + } + } + + func succeedRequest(_ buffer: CircularBuffer?) { + let succeedAction = self.stateLock.withLock { + self.state.succeedRequest(buffer) + } + switch succeedAction { + case .succeedRequest(let continuation): + continuation.resume(returning: nil) + case .succeedContinuation(let continuation, let byteBuffer): + continuation.resume(returning: byteBuffer) + case .none: + break + } + } + + // MARK: Other methods + + private func continueRequestBodyStream( + _ allocator: ByteBufferAllocator, + next: @escaping ((ByteBufferAllocator) async throws -> ByteBuffer?) + ) { + Task { + while true { + do { + guard let part = try await next(allocator) else { // <---- dispatch point! + return self.requestBodyStreamFinished() + } + + switch self.requestBodyStreamNextPart(part) { + case .pause: + return + case .continue: + continue + } + + } catch { + // producing more failed + self.requestBodyStreamFailed(error) + return + } + } + } + } +} + +@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) +extension AsyncRequestBag: HTTPSchedulableRequest { + var tlsConfiguration: TLSConfiguration? { + return nil + } + + var requiredEventLoop: EventLoop? { + return nil + } +} + +@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) +extension AsyncRequestBag: HTTPExecutableRequest { + func requestHeadSent() {} +} + +@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) +extension AsyncRequestBag { + func nextResponsePart(streamID: HTTPClientResponse.Body.IteratorStream.ID) async throws -> ByteBuffer? { + try await withUnsafeThrowingContinuation { continuation in + let action = self.stateLock.withLock { + self.state.consumeNextResponsePart(streamID: streamID, continuation: continuation) + } + switch action { + case .succeedContinuation(let continuation, let result): + continuation.resume(returning: result) + case .failContinuation(let continuation, let error): + continuation.resume(throwing: error) + case .askExecutorForMore(let executor): + executor.demandResponseBodyStream(self) + } + } + } + + func cancelResponseStream(streamID: HTTPClientResponse.Body.IteratorStream.ID) { + self.cancel() + } +} + +#endif diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift new file mode 100644 index 000000000..b1a826449 --- /dev/null +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift @@ -0,0 +1,86 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if compiler(>=5.5) && canImport(_Concurrency) +import Logging +import NIOCore + +@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) +extension HTTPClient { + func execute(_ request: HTTPClientRequest, deadline: NIODeadline, logger: Logger) async throws -> HTTPClientResponse { + actor SwiftCancellationHandlingIs🤔 { + enum State { + case initialized + case register(AsyncRequestBag) + case cancelled + } + + private var state: State = .initialized + + init() {} + + func registerRequestBag(_ bag: AsyncRequestBag) { + switch self.state { + case .initialized: + self.state = .register(bag) + case .cancelled: + bag.cancel() + case .register: + preconditionFailure() + } + } + + func cancel() { + switch self.state { + case .register(let bag): + self.state = .cancelled + bag.cancel() + case .cancelled: + break + case .initialized: + self.state = .cancelled + } + } + } + let preparedRequest = try HTTPClientRequest.Prepared(request) + + let cancelHandler = SwiftCancellationHandlingIs🤔() + + return try await withTaskCancellationHandler(operation: { () async throws -> HTTPClientResponse in + try await withUnsafeThrowingContinuation { + (continuation: UnsafeContinuation) -> Void in + let bag = AsyncRequestBag( + request: preparedRequest, + requestOptions: .init(idleReadTimeout: nil), + logger: logger, + connectionDeadline: .now() + .seconds(10), + preferredEventLoop: self.eventLoopGroup.next(), + responseContinuation: continuation + ) + + _Concurrency.Task { + await cancelHandler.registerRequestBag(bag) + } + + self.poolManager.executeRequest(bag) + } + }, onCancel: { + _Concurrency.Task { + await cancelHandler.cancel() + } + }) + } +} + +#endif diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift new file mode 100644 index 000000000..dc4fc7156 --- /dev/null +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift @@ -0,0 +1,96 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if compiler(>=5.5) && canImport(_Concurrency) +import NIOCore +import NIOHTTP1 + +@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) +struct HTTPClientResponse { + var version: HTTPVersion + var status: HTTPResponseStatus + var headers: HTTPHeaders + var body: Body + + struct Body { + private let bag: AsyncRequestBag + + fileprivate init(_ bag: AsyncRequestBag) { + self.bag = bag + } + } + + init( + bag: AsyncRequestBag, + version: HTTPVersion, + status: HTTPResponseStatus, + headers: HTTPHeaders + ) { + self.body = .init(bag) + self.version = version + self.status = status + self.headers = headers + } +} + +@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) +extension HTTPClientResponse.Body: AsyncSequence { + typealias Element = ByteBuffer + typealias AsyncIterator = Iterator + + struct Iterator: AsyncIteratorProtocol { + typealias Element = ByteBuffer + + private let stream: IteratorStream + + fileprivate init(stream: IteratorStream) { + self.stream = stream + } + + func next() async throws -> ByteBuffer? { + try await self.stream.next() + } + } + + func makeAsyncIterator() -> Iterator { + Iterator(stream: IteratorStream(bag: self.bag)) + } + + internal class IteratorStream { + struct ID: Hashable { + private let objectID: ObjectIdentifier + + init(_ object: IteratorStream) { + self.objectID = ObjectIdentifier(object) + } + } + + var id: ID { ID(self) } + private let bag: AsyncRequestBag + + init(bag: AsyncRequestBag) { + self.bag = bag + } + + deinit { + self.bag.cancelResponseStream(streamID: self.id) + } + + func next() async throws -> ByteBuffer? { + try await self.bag.nextResponsePart(streamID: self.id) + } + } +} + +#endif diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift index 09d2815f3..58fafef73 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift @@ -148,14 +148,24 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { case .sendRequestHead(let head, let startBody): if startBody { context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil) - self.request!.requestHeadSent() - self.request!.resumeRequestBodyStream() + + // Writing the header might lead to errors. For this reason, we need to check, if + // the request is still present. It might have been removed, because the request was + // already failed. + if let request = self.request { + request.requestHeadSent() + request.resumeRequestBodyStream() + } + } else { context.write(self.wrapOutboundOut(.head(head)), promise: nil) context.write(self.wrapOutboundOut(.end(nil)), promise: nil) context.flush() - self.request!.requestHeadSent() + // Writing the header might lead to errors. For this reason, we need to check, if + // the request is still present. It might have been removed, because the request was + // already failed. + self.request?.requestHeadSent() if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { self.runTimeoutAction(timeoutAction, context: context) diff --git a/Tests/AsyncHTTPClientTests/AsyncRequestTests+XCTest.swift b/Tests/AsyncHTTPClientTests/AsyncRequestTests+XCTest.swift new file mode 100644 index 000000000..04c44b59f --- /dev/null +++ b/Tests/AsyncHTTPClientTests/AsyncRequestTests+XCTest.swift @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// AsyncRequestTests+XCTest.swift +// +import XCTest + +/// +/// NOTE: This file was generated by generate_linux_tests.rb +/// +/// Do NOT edit this file directly as it will be regenerated automatically when needed. +/// + +extension AsyncRequestTests { + static var allTests: [(String, (AsyncRequestTests) -> () throws -> Void)] { + return [ + ("testCancelAsyncRequest", testCancelAsyncRequest), + ("testResponseStreamingWorks", testResponseStreamingWorks), + ("testWriteBackpressureWorks", testWriteBackpressureWorks), + ("testSimpleGetRequest", testSimpleGetRequest), + ("testBiDirectionalStreamingHTTP2", testBiDirectionalStreamingHTTP2), + ] + } +} diff --git a/Tests/AsyncHTTPClientTests/AsyncRequestTests.swift b/Tests/AsyncHTTPClientTests/AsyncRequestTests.swift new file mode 100644 index 000000000..f2e6147bc --- /dev/null +++ b/Tests/AsyncHTTPClientTests/AsyncRequestTests.swift @@ -0,0 +1,551 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@testable import AsyncHTTPClient +import Logging +import NIOCore +import NIOEmbedded +import NIOHTTP1 +import NIOPosix +import XCTest + +#if compiler(>=5.5) && canImport(_Concurrency) +@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) +typealias PreparedRequest = HTTPClientRequest.Prepared +#endif + +final class AsyncRequestTests: XCTestCase { + func testCancelAsyncRequest() { + #if compiler(>=5.5) && canImport(_Concurrency) + guard #available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) else { return } + XCTAsyncTest { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + + var request = HTTPClientRequest(url: "https://localhost/") + request.method = .GET + var maybePreparedRequest: PreparedRequest? + XCTAssertNoThrow(maybePreparedRequest = try PreparedRequest(request)) + guard let preparedRequest = maybePreparedRequest else { + return + } + let (requestBag, responseTask) = AsyncRequestBag.makeWithResultTask( + request: preparedRequest, + preferredEventLoop: embeddedEventLoop + ) + + Task.detached { + try await Task.sleep(nanoseconds: 5 * 1000 * 1000) + requestBag.cancel() + } + + do { + _ = try await responseTask.result.get() + XCTFail("Expected to throw error") + } catch { + XCTAssertEqual(error as? HTTPClientError, .cancelled) + } + } + #endif + } + + func testResponseStreamingWorks() { + #if compiler(>=5.5) && canImport(_Concurrency) + guard #available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) else { return } + XCTAsyncTest { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + + var request = HTTPClientRequest(url: "https://localhost/") + request.method = .GET + + var maybePreparedRequest: PreparedRequest? + XCTAssertNoThrow(maybePreparedRequest = try PreparedRequest(request)) + guard let preparedRequest = maybePreparedRequest else { + return + } + let (requestBag, responseTask) = AsyncRequestBag.makeWithResultTask( + request: preparedRequest, + preferredEventLoop: embeddedEventLoop + ) + + let executor = MockRequestExecutor( + pauseRequestBodyPartStreamAfterASingleWrite: true, + eventLoop: embeddedEventLoop + ) + + requestBag.willExecuteRequest(executor) + requestBag.requestHeadSent() + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["foo": "bar"]) + XCTAssertFalse(executor.signalledDemandForResponseBody) + requestBag.receiveResponseHead(responseHead) + + do { + let response = try await responseTask.result.get() + XCTAssertEqual(response.status, responseHead.status) + XCTAssertEqual(response.headers, responseHead.headers) + XCTAssertEqual(response.version, responseHead.version) + + let iterator = SharedIterator(response.body.filter { $0.readableBytes > 0 }.makeAsyncIterator()) + + for i in 0..<100 { + XCTAssertFalse(executor.signalledDemandForResponseBody, "Demand was not signalled yet.") + + async let part = iterator.next() + + try await Task.sleep(nanoseconds: 1000 * 1000) + + XCTAssertTrue(executor.signalledDemandForResponseBody, "Iterator caused new demand") + executor.resetDemandSignal() + requestBag.receiveResponseBodyParts([ByteBuffer(integer: i)]) + + let result = try await part + XCTAssertEqual(result, ByteBuffer(integer: i)) + } + + XCTAssertFalse(executor.signalledDemandForResponseBody, "Demand was not signalled yet.") + async let part = iterator.next() + try await Task.sleep(nanoseconds: 1000 * 1000) + XCTAssertTrue(executor.signalledDemandForResponseBody, "Iterator caused new demand") + executor.resetDemandSignal() + requestBag.succeedRequest([]) + let result = try await part + XCTAssertNil(result) + + } catch { + XCTFail("Failing tests are bad: \(error)") + } + } + #endif + } + + func testWriteBackpressureWorks() { + #if compiler(>=5.5) && canImport(_Concurrency) + guard #available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) else { return } + XCTAsyncTest { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + + let streamWriter = AsyncSequenceWriter() + if await streamWriter.hasDemand { XCTFail("Did not expect to have a demand at this point") } + + var request = HTTPClientRequest(url: "https://localhost/") + request.method = .POST + request.body = .stream(streamWriter) + + var maybePreparedRequest: PreparedRequest? + XCTAssertNoThrow(maybePreparedRequest = try PreparedRequest(request)) + guard let preparedRequest = maybePreparedRequest else { + return + } + let (requestBag, responseTask) = AsyncRequestBag.makeWithResultTask( + request: preparedRequest, + preferredEventLoop: embeddedEventLoop + ) + + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + + requestBag.willExecuteRequest(executor) + requestBag.requestHeadSent() + + do { + // we need to yield here to ensure, will execute and request head can trigger + await Task.yield() + + for i in 0..<100 { + if await streamWriter.hasDemand { + XCTFail("Did not expect to have demand yet") + } + + requestBag.resumeRequestBodyStream() + try await streamWriter.demand() // wait's for the stream writer to signal demand + requestBag.pauseRequestBodyStream() + + await Task.yield() + + let part = ByteBuffer(integer: i) + await streamWriter.write(part) + + // wait for the executor to be readable again + try await executor.readable().get() + let next = executor.nextBodyPart() + XCTAssertEqual(next, .body(.byteBuffer(part))) + } + + requestBag.resumeRequestBodyStream() + try await streamWriter.demand() + + await streamWriter.end() + try await executor.readable().get() + + let next = executor.nextBodyPart() + XCTAssertEqual(next, .endOfStream) + + // write response! + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["foo": "bar"]) + XCTAssertFalse(executor.signalledDemandForResponseBody) + requestBag.receiveResponseHead(responseHead) + + let response = try await responseTask.result.get() + XCTAssertEqual(response.status, responseHead.status) + XCTAssertEqual(response.headers, responseHead.headers) + XCTAssertEqual(response.version, responseHead.version) + + let iterator = SharedIterator(response.body.makeAsyncIterator()) + + XCTAssertFalse(executor.signalledDemandForResponseBody, "Demand was not signalled yet.") + async let part = iterator.next() + try await Task.sleep(nanoseconds: 1000 * 1000) + XCTAssertTrue(executor.signalledDemandForResponseBody, "Iterator caused new demand") + executor.resetDemandSignal() + requestBag.succeedRequest([]) + let result = try await part + XCTAssertNil(result) + } catch { + XCTFail("Failing tests are bad: \(error)") + } + } + #endif + } + + func testSimpleGetRequest() { + #if compiler(>=5.5) && canImport(_Concurrency) + guard #available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) else { return } + XCTAsyncTest { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let eventLoop = eventLoopGroup.next() + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + + let httpBin = HTTPBin(.http2(compress: false)) + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let connectionCreator = TestConnectionCreator() + let delegate = TestHTTP2ConnectionDelegate() + var maybeHTTP2Connection: HTTP2Connection? + XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + )) + guard let http2Connection = maybeHTTP2Connection else { + return XCTFail("Expected to have an HTTP2 connection here.") + } + + do { + var request = HTTPClientRequest(url: "https://localhost:\(httpBin.port)/") + request.headers = ["host": "localhost:\(httpBin.port)"] + + var maybePreparedRequest: PreparedRequest? + XCTAssertNoThrow(maybePreparedRequest = try PreparedRequest(request)) + guard let preparedRequest = maybePreparedRequest else { + return + } + let (requestBag, responseTask) = AsyncRequestBag.makeWithResultTask( + request: preparedRequest, + preferredEventLoop: eventLoopGroup.next() + ) + + http2Connection.executeRequest(requestBag) + + XCTAssertEqual(delegate.hitStreamClosed, 0) + + let response = try await responseTask.result.get() + + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.version, .http2) + XCTAssertEqual(delegate.hitStreamClosed, 1) + + var body = try await response.body.reduce(into: ByteBuffer()) { partialResult, next in + var next = next + partialResult.writeBuffer(&next) + } + XCTAssertEqual( + try body.readJSONDecodable(RequestInfo.self, length: body.readableBytes), + RequestInfo(data: "", requestNumber: 1, connectionNumber: 0) + ) + } catch { + XCTFail("We don't like errors in tests: \(error)") + } + } + #endif + } + + func testBiDirectionalStreamingHTTP2() { + #if compiler(>=5.5) && canImport(_Concurrency) + guard #available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) else { return } + XCTAsyncTest { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let eventLoop = eventLoopGroup.next() + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + + let httpBin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let connectionCreator = TestConnectionCreator() + let delegate = TestHTTP2ConnectionDelegate() + var maybeHTTP2Connection: HTTP2Connection? + XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + )) + guard let http2Connection = maybeHTTP2Connection else { + return XCTFail("Expected to have an HTTP2 connection here.") + } + + do { + let streamWriter = AsyncSequenceWriter() + if await streamWriter.hasDemand { XCTFail("Did not expect to have a demand at this point") } + + var request = HTTPClientRequest(url: "https://localhost:\(httpBin.port)/") + request.method = .POST + request.headers = ["host": "localhost:\(httpBin.port)"] + request.body = .stream(length: 800, streamWriter) + + var maybePreparedRequest: PreparedRequest? + XCTAssertNoThrow(maybePreparedRequest = try PreparedRequest(request)) + guard let preparedRequest = maybePreparedRequest else { + return + } + let (requestBag, responseTask) = AsyncRequestBag.makeWithResultTask( + request: preparedRequest, + preferredEventLoop: eventLoopGroup.next() + ) + + http2Connection.executeRequest(requestBag) + + XCTAssertEqual(delegate.hitStreamClosed, 0) + + let response = try await responseTask.result.get() + + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.version, .http2) + XCTAssertEqual(delegate.hitStreamClosed, 0) + + let iterator = SharedIterator(response.body.filter { $0.readableBytes > 0 }.makeAsyncIterator()) + + // at this point we can start to write to the stream and wait for the results + + for i in 0..<100 { + let buffer = ByteBuffer(integer: i) + await streamWriter.write(buffer) + var echoedBuffer = try await iterator.next() + guard let echoedInt = echoedBuffer?.readInteger(as: Int.self) else { + XCTFail("Expected to not be finished at this point") + break + } + XCTAssertEqual(i, echoedInt) + } + + XCTAssertEqual(delegate.hitStreamClosed, 0) + await streamWriter.end() + let final = try await iterator.next() + XCTAssertNil(final) + XCTAssertEqual(delegate.hitStreamClosed, 1) + + } catch { + print(error) + XCTFail("We don't like errors in tests: \(error)") + } + } + #endif + } +} + +#if compiler(>=5.5) && canImport(_Concurrency) + +// This needs a small explanation. If an iterator is a struct, it can't be used across multiple +// tasks. Since we want to wait for things to happen in tests, we need to `async let`, which creates +// implicit tasks. Therefore we need to wrap our iterator struct. +@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) +actor SharedIterator { + private var iterator: Iterator + + init(_ iterator: Iterator) { + self.iterator = iterator + } + + func next() async throws -> Iterator.Element? { + var iter = self.iterator + defer { self.iterator = iter } + return try await iter.next() + } +} + +@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) +actor AsyncSequenceWriter: AsyncSequence { + typealias AsyncIterator = Iterator + typealias Element = ByteBuffer + + struct Iterator: AsyncIteratorProtocol { + typealias Element = ByteBuffer + + private let writer: AsyncSequenceWriter + + init(_ writer: AsyncSequenceWriter) { + self.writer = writer + } + + mutating func next() async throws -> ByteBuffer? { + try await self.writer.next() + } + } + + nonisolated func makeAsyncIterator() -> Iterator { + return Iterator(self) + } + + enum State { + case buffering(CircularBuffer, CheckedContinuation?) + case finished + case waiting(UnsafeContinuation) + case failed(Error) + } + + private var state: State = .buffering(.init(), nil) + + public var hasDemand: Bool { + switch self.state { + case .failed, .finished, .buffering: + return false + case .waiting: + return true + } + } + + public func demand() async throws { + switch self.state { + case .buffering(let buffer, .none): + try await withCheckedThrowingContinuation { continuation in + self.state = .buffering(buffer, continuation) + } + + case .waiting: + return + + case .buffering(_, .some): + preconditionFailure("Already waiting for demand") + + case .finished, .failed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + private func next() async throws -> ByteBuffer? { + switch self.state { + case .buffering(let buffer, let demandContinuation) where buffer.isEmpty: + return try await withUnsafeThrowingContinuation { continuation in + self.state = .waiting(continuation) + + demandContinuation?.resume(returning: ()) + } + + case .buffering(var buffer, let demandContinuation): + let first = buffer.popFirst()! + if first != nil { + self.state = .buffering(buffer, demandContinuation) + } else { + self.state = .finished + } + return first + + case .failed(let error): + self.state = .finished + throw error + + case .finished: + return nil + + case .waiting: + preconditionFailure("How can this be called twice?!") + } + } + + public func write(_ byteBuffer: ByteBuffer) { + switch self.state { + case .buffering(var buffer, let continuation): + buffer.append(byteBuffer) + self.state = .buffering(buffer, continuation) + + case .waiting(let continuation): + self.state = .buffering(.init(), nil) + continuation.resume(returning: byteBuffer) + + case .finished, .failed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + public func end() { + switch self.state { + case .buffering(var buffer, let continuation): + buffer.append(nil) + self.state = .buffering(buffer, continuation) + + case .waiting(let continuation): + self.state = .finished + continuation.resume(returning: nil) + + case .finished, .failed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + public func fail(_ error: Error) { + switch self.state { + case .buffering: + self.state = .failed(error) + + case .failed, .finished: + return + + case .waiting(let continuation): + self.state = .finished + continuation.resume(throwing: error) + } + } +} + +@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) +extension AsyncRequestBag { + fileprivate static func makeWithResultTask( + request: PreparedRequest, + requestOptions: RequestOptions = .forTests(), + logger: Logger = Logger(label: "test"), + connectionDeadline: NIODeadline = .distantFuture, + preferredEventLoop: EventLoop + ) -> (AsyncRequestBag, _Concurrency.Task) { + let requestBagPromise = preferredEventLoop.makePromise(of: AsyncRequestBag.self) + let result = Task { + try await withUnsafeThrowingContinuation { (continuation: UnsafeContinuation) in + let requestBag = AsyncRequestBag( + request: request, + requestOptions: requestOptions, + logger: logger, + connectionDeadline: connectionDeadline, + preferredEventLoop: preferredEventLoop, + responseContinuation: continuation + ) + requestBagPromise.succeed(requestBag) + } + } + // the promise can never fail and it is therefore safe to force unwrap + let requestBag = try! requestBagPromise.futureResult.wait() + + return (requestBag, result) + } +} +#endif diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index e71490bd1..6a1b78e4a 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -641,7 +641,7 @@ internal struct HTTPResponseBuilder { } } -internal struct RequestInfo: Codable { +internal struct RequestInfo: Codable, Equatable { var data: String var requestNumber: Int var connectionNumber: Int diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests.swift b/Tests/AsyncHTTPClientTests/RequestBagTests.swift index 1acde7732..a574944a0 100644 --- a/Tests/AsyncHTTPClientTests/RequestBagTests.swift +++ b/Tests/AsyncHTTPClientTests/RequestBagTests.swift @@ -14,6 +14,7 @@ @testable import AsyncHTTPClient import Logging +import NIOConcurrencyHelpers import NIOCore import NIOEmbedded import NIOHTTP1 @@ -70,7 +71,10 @@ final class RequestBagTests: XCTestCase { XCTAssert(bag.task.eventLoop === embeddedEventLoop) - let executor = MockRequestExecutor(pauseRequestBodyPartStreamAfterASingleWrite: true) + let executor = MockRequestExecutor( + pauseRequestBodyPartStreamAfterASingleWrite: true, + eventLoop: embeddedEventLoop + ) bag.willExecuteRequest(executor) @@ -99,7 +103,7 @@ final class RequestBagTests: XCTestCase { streamIsAllowedToWrite = true bag.resumeRequestBodyStream() streamIsAllowedToWrite = false - XCTAssertLessThanOrEqual(executor.requestBodyParts.count, 2) + XCTAssertLessThanOrEqual(executor.requestBodyPartsCount, 2) XCTAssertEqual(delegate.hitDidSendRequestPart, writes) } } @@ -110,6 +114,7 @@ final class RequestBagTests: XCTestCase { let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: .init([ ("Transfer-Encoding", "chunked"), ])) + XCTAssertFalse(executor.signalledDemandForResponseBody) bag.receiveResponseHead(responseHead) XCTAssertEqual(responseHead, delegate.receivedHead) XCTAssertNoThrow(try XCTUnwrap(delegate.backpressurePromise).succeed(())) @@ -178,7 +183,7 @@ final class RequestBagTests: XCTestCase { guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } XCTAssert(bag.task.eventLoop === embeddedEventLoop) - let executor = MockRequestExecutor() + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) bag.willExecuteRequest(executor) @@ -221,7 +226,7 @@ final class RequestBagTests: XCTestCase { guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } XCTAssert(bag.eventLoop === embeddedEventLoop) - let executor = MockRequestExecutor() + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) bag.cancel() bag.willExecuteRequest(executor) @@ -254,7 +259,7 @@ final class RequestBagTests: XCTestCase { guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } XCTAssert(bag.eventLoop === embeddedEventLoop) - let executor = MockRequestExecutor() + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) bag.willExecuteRequest(executor) XCTAssertFalse(executor.isCancelled) @@ -329,7 +334,7 @@ final class RequestBagTests: XCTestCase { )) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } - let executor = MockRequestExecutor() + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) bag.willExecuteRequest(executor) bag.requestHeadSent() bag.receiveResponseHead(.init(version: .http1_1, status: .ok)) @@ -386,7 +391,7 @@ final class RequestBagTests: XCTestCase { )) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } - let executor = MockRequestExecutor() + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) bag.willExecuteRequest(executor) XCTAssertEqual(delegate.hitDidSendRequestHead, 0) @@ -431,7 +436,7 @@ final class RequestBagTests: XCTestCase { )) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } - let executor = MockRequestExecutor() + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) bag.willExecuteRequest(executor) bag.requestHeadSent() bag.receiveResponseHead(.init(version: .http1_1, status: .ok)) @@ -472,45 +477,98 @@ class MockRequestExecutor: HTTPRequestExecutor { case endOfStream } + let eventLoop: EventLoop + let lock = Lock() let pauseRequestBodyPartStreamAfterASingleWrite: Bool - private(set) var requestBodyParts = CircularBuffer() - private(set) var isCancelled: Bool = false - private(set) var signalledDemandForResponseBody: Bool = false + var isCancelled: Bool { + self.lock.withLock { self._isCancelled } + } + + var signalledDemandForResponseBody: Bool { + self.lock.withLock { self._signaledDemandForResponseBody } + } + + var requestBodyPartsCount: Int { + self.lock.withLock { self._requestBodyParts.count } + } + + private(set) var _requestBodyParts = CircularBuffer() + private(set) var _isCancelled: Bool = false + private(set) var _signaledDemandForResponseBody: Bool = false + private(set) var _readable: EventLoopPromise? - init(pauseRequestBodyPartStreamAfterASingleWrite: Bool = false) { + init(pauseRequestBodyPartStreamAfterASingleWrite: Bool = false, eventLoop: EventLoop) { self.pauseRequestBodyPartStreamAfterASingleWrite = pauseRequestBodyPartStreamAfterASingleWrite + self.eventLoop = eventLoop } func nextBodyPart() -> RequestParts? { - guard !self.requestBodyParts.isEmpty else { return nil } - return self.requestBodyParts.removeFirst() + self.lock.withLock { () -> RequestParts? in + guard !self._requestBodyParts.isEmpty else { return nil } + return self._requestBodyParts.removeFirst() + } } func resetDemandSignal() { - self.signalledDemandForResponseBody = false + self.lock.withLockVoid { + self._signaledDemandForResponseBody = false + } + } + + func readable() -> EventLoopFuture { + self.lock.withLock { () -> EventLoopFuture in + if !self._requestBodyParts.isEmpty { + return self.eventLoop.makeSucceededVoidFuture() + } + + let promise = self.eventLoop.makePromise(of: Void.self) + self._readable = promise + return promise.futureResult + } } // this should always be called twice. When we receive the first call, the next call to produce // data is already scheduled. If we call pause here, once, after the second call new subsequent // calls should not be scheduled. func writeRequestBodyPart(_ part: IOData, request: HTTPExecutableRequest) { - if self.requestBodyParts.isEmpty, self.pauseRequestBodyPartStreamAfterASingleWrite { + let (pause, promise) = self.lock.withLock { () -> (Bool, EventLoopPromise?) in + var pause = false + if self._requestBodyParts.isEmpty, self.pauseRequestBodyPartStreamAfterASingleWrite { + pause = true + } + self._requestBodyParts.append(.body(part)) + let promise = self._readable + self._readable = nil + return (pause, promise) + } + + if pause { request.pauseRequestBodyStream() } - self.requestBodyParts.append(.body(part)) + + promise?.succeed(()) } func finishRequestBodyStream(_: HTTPExecutableRequest) { - self.requestBodyParts.append(.endOfStream) + let promise = self.lock.withLock { () -> EventLoopPromise? in + self._requestBodyParts.append(.endOfStream) + let promise = self._readable + self._readable = nil + return promise + } + + promise?.succeed(()) } func demandResponseBodyStream(_: HTTPExecutableRequest) { - self.signalledDemandForResponseBody = true + self.lock.withLockVoid { + self._signaledDemandForResponseBody = true + } } func cancelRequest(_: HTTPExecutableRequest) { - self.isCancelled = true + self._isCancelled = true } } diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index 44d15b401..7a8c78363 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -26,6 +26,7 @@ import XCTest @testable import AsyncHTTPClientTests XCTMain([ + testCase(AsyncRequestTests.allTests), testCase(HTTP1ClientChannelHandlerTests.allTests), testCase(HTTP1ConnectionStateMachineTests.allTests), testCase(HTTP1ConnectionTests.allTests),