diff --git a/.github/release.yml b/.github/release.yml index 13c29b0e6..e29eb8464 100644 --- a/.github/release.yml +++ b/.github/release.yml @@ -2,13 +2,13 @@ changelog: categories: - title: SemVer Major labels: - - semver/major + - ⚠️ semver/major - title: SemVer Minor labels: - - semver/minor + - 🆕 semver/minor - title: SemVer Patch labels: - - semver/patch + - 🔨 semver/patch - title: Other Changes labels: - semver/none diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6e5453369..3bf5a95ec 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -11,8 +11,13 @@ jobs: name: Unit tests uses: apple/swift-nio/.github/workflows/unit_tests.yml@main with: - linux_5_9_arguments_override: "-Xswiftc -warnings-as-errors --explicit-target-dependency-import-check error" - linux_5_10_arguments_override: "-Xswiftc -warnings-as-errors --explicit-target-dependency-import-check error" - linux_6_0_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" - linux_nightly_6_0_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" - linux_nightly_main_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" + linux_5_10_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_6_0_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_6_1_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_nightly_next_arguments_override: "--explicit-target-dependency-import-check error" + linux_nightly_main_arguments_override: "--explicit-target-dependency-import-check error" + + static-sdk: + name: Static SDK + # Workaround https://github.com/nektos/act/issues/1875 + uses: apple/swift-nio/.github/workflows/static_sdk.yml@main diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index 0392cb7c5..8036d7ad7 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -14,12 +14,19 @@ jobs: name: Unit tests uses: apple/swift-nio/.github/workflows/unit_tests.yml@main with: - linux_5_9_arguments_override: "-Xswiftc -warnings-as-errors --explicit-target-dependency-import-check error" - linux_5_10_arguments_override: "-Xswiftc -warnings-as-errors --explicit-target-dependency-import-check error" - linux_6_0_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" - linux_nightly_6_0_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" - linux_nightly_main_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" + linux_5_10_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_6_0_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_6_1_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_nightly_next_arguments_override: "--explicit-target-dependency-import-check error" + linux_nightly_main_arguments_override: "--explicit-target-dependency-import-check error" cxx-interop: name: Cxx interop uses: apple/swift-nio/.github/workflows/cxx_interop.yml@main + with: + linux_5_9_enabled: false + + static-sdk: + name: Static SDK + # Workaround https://github.com/nektos/act/issues/1875 + uses: apple/swift-nio/.github/workflows/static_sdk.yml@main diff --git a/Package.swift b/Package.swift index e4cccb6de..3294781a9 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.8 +// swift-tools-version:5.10 //===----------------------------------------------------------------------===// // // This source file is part of the AsyncHTTPClient open source project @@ -15,18 +15,36 @@ import PackageDescription +let strictConcurrencyDevelopment = false + +let strictConcurrencySettings: [SwiftSetting] = { + var initialSettings: [SwiftSetting] = [] + initialSettings.append(contentsOf: [ + .enableUpcomingFeature("StrictConcurrency"), + .enableUpcomingFeature("InferSendableFromCaptures"), + ]) + + if strictConcurrencyDevelopment { + // -warnings-as-errors here is a workaround so that IDE-based development can + // get tripped up on -require-explicit-sendable. + initialSettings.append(.unsafeFlags(["-Xfrontend", "-require-explicit-sendable", "-warnings-as-errors"])) + } + + return initialSettings +}() + let package = Package( name: "async-http-client", products: [ .library(name: "AsyncHTTPClient", targets: ["AsyncHTTPClient"]) ], dependencies: [ - .package(url: "https://github.com/apple/swift-nio.git", from: "2.71.0"), - .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.27.1"), - .package(url: "https://github.com/apple/swift-nio-http2.git", from: "1.19.0"), - .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.13.0"), - .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.19.0"), - .package(url: "https://github.com/apple/swift-log.git", from: "1.4.4"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.81.0"), + .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.30.0"), + .package(url: "https://github.com/apple/swift-nio-http2.git", from: "1.36.0"), + .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.26.0"), + .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.24.0"), + .package(url: "https://github.com/apple/swift-log.git", from: "1.6.0"), .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"), .package(url: "https://github.com/apple/swift-algorithms.git", from: "1.0.0"), ], @@ -55,7 +73,8 @@ let package = Package( .product(name: "Logging", package: "swift-log"), .product(name: "Atomics", package: "swift-atomics"), .product(name: "Algorithms", package: "swift-algorithms"), - ] + ], + swiftSettings: strictConcurrencySettings ), .testTarget( name: "AsyncHTTPClientTests", @@ -79,18 +98,24 @@ let package = Package( .copy("Resources/self_signed_key.pem"), .copy("Resources/example.com.cert.pem"), .copy("Resources/example.com.private-key.pem"), - ] + ], + swiftSettings: strictConcurrencySettings ), ] ) // --- STANDARD CROSS-REPO SETTINGS DO NOT EDIT --- // for target in package.targets { - if target.type != .plugin { + switch target.type { + case .regular, .test, .executable: var settings = target.swiftSettings ?? [] // https://github.com/swiftlang/swift-evolution/blob/main/proposals/0444-member-import-visibility.md settings.append(.enableUpcomingFeature("MemberImportVisibility")) target.swiftSettings = settings + case .macro, .plugin, .system, .binary: + () // not applicable + @unknown default: + () // we don't know what to do here, do nothing } } // --- END: STANDARD CROSS-REPO SETTINGS DO NOT EDIT --- // diff --git a/README.md b/README.md index 871eb910b..a4f49c8c8 100644 --- a/README.md +++ b/README.md @@ -306,7 +306,7 @@ Please have a look at [SECURITY.md](SECURITY.md) for AsyncHTTPClient's security ## Supported Versions -The most recent versions of AsyncHTTPClient support Swift 5.6 and newer. The minimum Swift version supported by AsyncHTTPClient releases are detailed below: +The most recent versions of AsyncHTTPClient support Swift 5.10 and newer. The minimum Swift version supported by AsyncHTTPClient releases are detailed below: AsyncHTTPClient | Minimum Swift Version --------------------|---------------------- @@ -316,4 +316,6 @@ AsyncHTTPClient | Minimum Swift Version `1.13.0 ..< 1.18.0` | 5.5.2 `1.18.0 ..< 1.20.0` | 5.6 `1.20.0 ..< 1.21.0` | 5.7 -`1.21.0 ...` | 5.8 +`1.21.0 ..< 1.26.0` | 5.8 +`1.26.0 ..< 1.27.0` | 5.9 +`1.27.0 ...` | 5.10 diff --git a/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequence.swift b/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequence.swift index 8f6b32bd2..fbcc82ec1 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequence.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequence.swift @@ -46,3 +46,6 @@ struct AnyAsyncSequence: Sendable, AsyncSequence { .init(nextCallback: self.makeAsyncIteratorCallback()) } } + +@available(*, unavailable) +extension AnyAsyncSequence.AsyncIterator: Sendable {} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift index fc1dbc209..5fc1be9f5 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift @@ -26,6 +26,10 @@ extension HTTPClient { /// - request: HTTP request to execute. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. + /// + /// - warning: This method may violates Structured Concurrency because it returns a `HTTPClientResponse` that needs to be + /// streamed by the user. This means the request, the connection and other resources are still alive when the request returns. + /// /// - Returns: The response to the request. Note that the `body` of the response may not yet have been fully received. public func execute( _ request: HTTPClientRequest, @@ -51,6 +55,10 @@ extension HTTPClient { /// - request: HTTP request to execute. /// - timeout: time the the request has to complete. /// - logger: The logger to use for this request. + /// + /// - warning: This method may violates Structured Concurrency because it returns a `HTTPClientResponse` that needs to be + /// streamed by the user. This means the request, the connection and other resources are still alive when the request returns. + /// /// - Returns: The response to the request. Note that the `body` of the response may not yet have been fully received. public func execute( _ request: HTTPClientRequest, @@ -67,6 +75,8 @@ extension HTTPClient { @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClient { + /// - warning: This method may violates Structured Concurrency because it returns a `HTTPClientResponse` that needs to be + /// streamed by the user. This means the request, the connection and other resources are still alive when the request returns. private func executeAndFollowRedirectsIfNeeded( _ request: HTTPClientRequest, deadline: NIODeadline, @@ -75,11 +85,29 @@ extension HTTPClient { ) async throws -> HTTPClientResponse { var currentRequest = request var currentRedirectState = redirectState + var history: [HTTPClientRequestResponse] = [] // this loop is there to follow potential redirects while true { let preparedRequest = try HTTPClientRequest.Prepared(currentRequest, dnsOverride: configuration.dnsOverride) - let response = try await self.executeCancellable(preparedRequest, deadline: deadline, logger: logger) + let response = try await { + var response = try await self.executeCancellable(preparedRequest, deadline: deadline, logger: logger) + + history.append( + .init( + request: currentRequest, + responseHead: .init( + version: response.version, + status: response.status, + headers: response.headers + ) + ) + ) + + response.history = history + + return response + }() guard var redirectState = currentRedirectState else { // a `nil` redirectState means we should not follow redirects @@ -116,6 +144,8 @@ extension HTTPClient { } } + /// - warning: This method may violates Structured Concurrency because it returns a `HTTPClientResponse` that needs to be + /// streamed by the user. This means the request, the connection and other resources are still alive when the request returns. private func executeCancellable( _ request: HTTPClientRequest.Prepared, deadline: NIODeadline, diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift index d4eeae03e..c39452897 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift @@ -24,7 +24,7 @@ extension HTTPClientRequest { enum Body { case asyncSequence( length: RequestBodyLength, - nextBodyPart: (ByteBufferAllocator) async throws -> ByteBuffer? + makeAsyncIterator: @Sendable () -> ((ByteBufferAllocator) async throws -> ByteBuffer?) ) case sequence( length: RequestBodyLength, @@ -80,7 +80,7 @@ extension HTTPClientRequest.Prepared.Body { init(_ body: HTTPClientRequest.Body) { switch body.mode { case .asyncSequence(let length, let makeAsyncIterator): - self = .asyncSequence(length: length, nextBodyPart: makeAsyncIterator()) + self = .asyncSequence(length: length, makeAsyncIterator: makeAsyncIterator) case .sequence(let length, let canBeConsumedMultipleTimes, let makeCompleteBody): self = .sequence( length: length, diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift index f07a2ed41..dca7de0ef 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift @@ -421,3 +421,9 @@ extension HTTPClientRequest.Body { } } } + +@available(*, unavailable) +extension HTTPClientRequest.Body.AsyncIterator: Sendable {} + +@available(*, unavailable) +extension HTTPClientRequest.Body.AsyncIterator.Storage: Sendable {} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift index 832eb7b41..36c1cb36f 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift @@ -15,6 +15,8 @@ import NIOCore import NIOHTTP1 +import struct Foundation.URL + /// A representation of an HTTP response for the Swift Concurrency HTTPClient API. /// /// This object is similar to ``HTTPClient/Response``, but used for the Swift Concurrency API. @@ -32,6 +34,18 @@ public struct HTTPClientResponse: Sendable { /// The body of this HTTP response. public var body: Body + /// The history of all requests and responses in redirect order. + public var history: [HTTPClientRequestResponse] + + /// The target URL (after redirects) of the response. + public var url: URL? { + guard let lastRequestURL = self.history.last?.request.url else { + return nil + } + + return URL(string: lastRequestURL) + } + @inlinable public init( version: HTTPVersion = .http1_1, status: HTTPResponseStatus = .ok, @@ -42,6 +56,21 @@ public struct HTTPClientResponse: Sendable { self.status = status self.headers = headers self.body = body + self.history = [] + } + + @inlinable public init( + version: HTTPVersion = .http1_1, + status: HTTPResponseStatus = .ok, + headers: HTTPHeaders = [:], + body: Body = Body(), + history: [HTTPClientRequestResponse] = [] + ) { + self.version = version + self.status = status + self.headers = headers + self.body = body + self.history = history } init( @@ -49,7 +78,8 @@ public struct HTTPClientResponse: Sendable { version: HTTPVersion, status: HTTPResponseStatus, headers: HTTPHeaders, - body: TransactionBody + body: TransactionBody, + history: [HTTPClientRequestResponse] ) { self.init( version: version, @@ -64,11 +94,23 @@ public struct HTTPClientResponse: Sendable { status: status ) ) - ) + ), + history: history ) } } +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +public struct HTTPClientRequestResponse: Sendable { + public var request: HTTPClientRequest + public var responseHead: HTTPResponseHead + + public init(request: HTTPClientRequest, responseHead: HTTPResponseHead) { + self.request = request + self.responseHead = responseHead + } +} + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClientResponse { /// A representation of the response body for an HTTP response. @@ -217,3 +259,9 @@ extension HTTPClientResponse.Body { .stream(CollectionOfOne(byteBuffer).async) } } + +@available(*, unavailable) +extension HTTPClientResponse.Body.AsyncIterator: Sendable {} + +@available(*, unavailable) +extension HTTPClientResponse.Body.Storage.AsyncIterator: Sendable {} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift b/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift index 6cf0dbc07..457627a8a 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift @@ -552,3 +552,6 @@ extension Transaction { } } } + +@available(*, unavailable) +extension Transaction.StateMachine: Sendable {} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift b/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift index 408ebeeb6..6bf8b38b7 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift @@ -77,9 +77,11 @@ final class Transaction: private func continueRequestBodyStream( _ allocator: ByteBufferAllocator, - next: @escaping ((ByteBufferAllocator) async throws -> ByteBuffer?) + makeAsyncIterator: @Sendable @escaping () -> ((ByteBufferAllocator) async throws -> ByteBuffer?) ) { Task { + let next = makeAsyncIterator() + do { while let part = try await next(allocator) { do { @@ -199,9 +201,9 @@ extension Transaction: HTTPExecutableRequest { case .startStream(let allocator): switch self.request.body { - case .asyncSequence(_, let next): + case .asyncSequence(_, let makeAsyncIterator): // it is safe to call this async here. it dispatches... - self.continueRequestBodyStream(allocator, next: next) + self.continueRequestBodyStream(allocator, makeAsyncIterator: makeAsyncIterator) case .byteBuffer(let byteBuffer): self.writeOnceAndOneTimeOnly(byteBuffer: byteBuffer) @@ -242,7 +244,8 @@ extension Transaction: HTTPExecutableRequest { version: head.version, status: head.status, headers: head.headers, - body: body + body: body, + history: [] ) continuation.resume(returning: response) } diff --git a/Sources/AsyncHTTPClient/Base64.swift b/Sources/AsyncHTTPClient/Base64.swift index 3162e7251..4d2ddcc49 100644 --- a/Sources/AsyncHTTPClient/Base64.swift +++ b/Sources/AsyncHTTPClient/Base64.swift @@ -29,7 +29,7 @@ extension String { // swift-format-ignore: DontRepeatTypeInStaticProperties @usableFromInline -internal struct Base64 { +internal struct Base64: Sendable { @inlinable static func encode( diff --git a/Sources/AsyncHTTPClient/ConnectionPool.swift b/Sources/AsyncHTTPClient/ConnectionPool.swift index 35f7a21c4..b5b058c2e 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool.swift @@ -20,7 +20,9 @@ import NIOSSL import Darwin.C #elseif canImport(Musl) import Musl -#elseif os(Linux) || os(FreeBSD) || os(Android) +#elseif canImport(Android) +import Android +#elseif os(Linux) || os(FreeBSD) import Glibc #else #error("unsupported target operating system") diff --git a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/HTTP1ProxyConnectHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/HTTP1ProxyConnectHandler.swift index db7b7b7ef..1636fe379 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/HTTP1ProxyConnectHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/HTTP1ProxyConnectHandler.swift @@ -137,7 +137,7 @@ final class HTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableChannelHand return } - let timeout = context.eventLoop.scheduleTask(deadline: self.deadline) { + let timeout = context.eventLoop.assumeIsolated().scheduleTask(deadline: self.deadline) { switch self.state { case .initialized: preconditionFailure("How can we have a scheduled timeout, if the connection is not even up?") diff --git a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SOCKSEventsHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SOCKSEventsHandler.swift index a98f97d4d..7458627fd 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SOCKSEventsHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SOCKSEventsHandler.swift @@ -99,7 +99,7 @@ final class SOCKSEventsHandler: ChannelInboundHandler, RemovableChannelHandler { return } - let scheduled = context.eventLoop.scheduleTask(deadline: self.deadline) { + let scheduled = context.eventLoop.assumeIsolated().scheduleTask(deadline: self.deadline) { switch self.state { case .initialized, .channelActive: // close the connection, if the handshake timed out diff --git a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/TLSEventsHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/TLSEventsHandler.swift index bebd0bcc7..d210b2747 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/TLSEventsHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/TLSEventsHandler.swift @@ -104,7 +104,7 @@ final class TLSEventsHandler: ChannelInboundHandler, RemovableChannelHandler { var scheduled: Scheduled? if let deadline = deadline { - scheduled = context.eventLoop.scheduleTask(deadline: deadline) { + scheduled = context.eventLoop.assumeIsolated().scheduleTask(deadline: deadline) { switch self.state { case .initialized, .channelActive: // close the connection, if the handshake timed out diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift index 74a0c72d7..191517c71 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift @@ -185,7 +185,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.runTimeoutAction(timeoutAction, context: context) } - req.willExecuteRequest(self) + req.willExecuteRequest(self.requestExecutor) let action = self.state.runNewRequest( head: req.requestHead, @@ -314,6 +314,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { let oldRequest = self.request! self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) + self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) switch finalAction { case .close: @@ -322,7 +323,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { case .sendRequestEnd(let writePromise, let shouldClose): let writePromise = writePromise ?? context.eventLoop.makePromise(of: Void.self) // We need to defer succeeding the old request to avoid ordering issues - writePromise.futureResult.hop(to: context.eventLoop).whenComplete { result in + writePromise.futureResult.hop(to: context.eventLoop).assumeIsolated().whenComplete { result in switch result { case .success: // If our final action was `sendRequestEnd`, that means we've already received @@ -353,6 +354,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { let oldRequest = self.request! self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) + self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) switch finalAction { case .close(let writePromise): @@ -394,7 +396,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { assert(self.idleReadTimeoutTimer == nil, "Expected there is no timeout timer so far.") let timerID = self.currentIdleReadTimeoutTimerID - self.idleReadTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + self.idleReadTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { guard self.currentIdleReadTimeoutTimerID == timerID else { return } let action = self.state.idleReadTimeoutTriggered() self.run(action, context: context) @@ -407,7 +409,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.currentIdleReadTimeoutTimerID &+= 1 let timerID = self.currentIdleReadTimeoutTimerID - self.idleReadTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + self.idleReadTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { guard self.currentIdleReadTimeoutTimerID == timerID else { return } let action = self.state.idleReadTimeoutTriggered() self.run(action, context: context) @@ -429,7 +431,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { assert(self.idleWriteTimeoutTimer == nil, "Expected there is no timeout timer so far.") let timerID = self.currentIdleWriteTimeoutTimerID - self.idleWriteTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + self.idleWriteTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { guard self.currentIdleWriteTimeoutTimerID == timerID else { return } let action = self.state.idleWriteTimeoutTriggered() self.run(action, context: context) @@ -441,7 +443,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.currentIdleWriteTimeoutTimerID &+= 1 let timerID = self.currentIdleWriteTimeoutTimerID - self.idleWriteTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + self.idleWriteTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { guard self.currentIdleWriteTimeoutTimerID == timerID else { return } let action = self.state.idleWriteTimeoutTriggered() self.run(action, context: context) @@ -459,8 +461,11 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { // MARK: Private HTTPRequestExecutor - private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) - { + fileprivate func writeRequestBodyPart0( + _ data: IOData, + request: HTTPExecutableRequest, + promise: EventLoopPromise? + ) { guard self.request === request, let context = self.channelContext else { // Because the HTTPExecutableRequest may run in a different thread to our eventLoop, // calls from the HTTPExecutableRequest to our ChannelHandler may arrive here after @@ -479,7 +484,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.run(action, context: context) } - private func finishRequestBodyStream0(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { + fileprivate func finishRequestBodyStream0(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { guard self.request === request, let context = self.channelContext else { // See code comment in `writeRequestBodyPart0` promise?.fail(HTTPClientError.requestStreamCancelled) @@ -490,7 +495,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.run(action, context: context) } - private func demandResponseBodyStream0(_ request: HTTPExecutableRequest) { + fileprivate func demandResponseBodyStream0(_ request: HTTPExecutableRequest) { guard self.request === request, let context = self.channelContext else { // See code comment in `writeRequestBodyPart0` return @@ -502,7 +507,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.run(action, context: context) } - private func cancelRequest0(_ request: HTTPExecutableRequest) { + fileprivate func cancelRequest0(_ request: HTTPExecutableRequest) { guard self.request === request, let context = self.channelContext else { // See code comment in `writeRequestBodyPart0` return @@ -522,43 +527,39 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { @available(*, unavailable) extension HTTP1ClientChannelHandler: Sendable {} -extension HTTP1ClientChannelHandler: HTTPRequestExecutor { - func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { - if self.eventLoop.inEventLoop { - self.writeRequestBodyPart0(data, request: request, promise: promise) - } else { - self.eventLoop.execute { - self.writeRequestBodyPart0(data, request: request, promise: promise) +extension HTTP1ClientChannelHandler { + var requestExecutor: RequestExecutor { + RequestExecutor(self) + } + + struct RequestExecutor: HTTPRequestExecutor, Sendable { + private let loopBound: NIOLoopBound + + init(_ handler: HTTP1ClientChannelHandler) { + self.loopBound = NIOLoopBound(handler, eventLoop: handler.eventLoop) + } + + func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { + self.loopBound.execute { + $0.writeRequestBodyPart0(data, request: request, promise: promise) } } - } - func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { - if self.eventLoop.inEventLoop { - self.finishRequestBodyStream0(request, promise: promise) - } else { - self.eventLoop.execute { - self.finishRequestBodyStream0(request, promise: promise) + func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { + self.loopBound.execute { + $0.finishRequestBodyStream0(request, promise: promise) } } - } - func demandResponseBodyStream(_ request: HTTPExecutableRequest) { - if self.eventLoop.inEventLoop { - self.demandResponseBodyStream0(request) - } else { - self.eventLoop.execute { - self.demandResponseBodyStream0(request) + func demandResponseBodyStream(_ request: HTTPExecutableRequest) { + self.loopBound.execute { + $0.demandResponseBodyStream0(request) } } - } - func cancelRequest(_ request: HTTPExecutableRequest) { - if self.eventLoop.inEventLoop { - self.cancelRequest0(request) - } else { - self.eventLoop.execute { - self.cancelRequest0(request) + func cancelRequest(_ request: HTTPExecutableRequest) { + self.loopBound.execute { + $0.cancelRequest0(request) } } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1Connection.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1Connection.swift index e0496f2e3..6f64e0407 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1Connection.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1Connection.swift @@ -17,9 +17,9 @@ import NIOCore import NIOHTTP1 import NIOHTTPCompression -protocol HTTP1ConnectionDelegate { - func http1ConnectionReleased(_: HTTP1Connection) - func http1ConnectionClosed(_: HTTP1Connection) +protocol HTTP1ConnectionDelegate: Sendable { + func http1ConnectionReleased(_: HTTPConnectionPool.Connection.ID) + func http1ConnectionClosed(_: HTTPConnectionPool.Connection.ID) } final class HTTP1Connection { @@ -67,32 +67,45 @@ final class HTTP1Connection { return connection } - func executeRequest(_ request: HTTPExecutableRequest) { - if self.channel.eventLoop.inEventLoop { - self.execute0(request: request) - } else { - self.channel.eventLoop.execute { - self.execute0(request: request) + var sendableView: SendableView { + SendableView(self) + } + + struct SendableView: Sendable { + private let connection: NIOLoopBound + let channel: Channel + let id: HTTPConnectionPool.Connection.ID + private var eventLoop: EventLoop { self.connection.eventLoop } + + init(_ connection: HTTP1Connection) { + self.connection = NIOLoopBound(connection, eventLoop: connection.channel.eventLoop) + self.id = connection.id + self.channel = connection.channel + } + + func executeRequest(_ request: HTTPExecutableRequest) { + self.connection.execute { + $0.execute0(request: request) } } - } - func shutdown() { - self.channel.triggerUserOutboundEvent(HTTPConnectionEvent.shutdownRequested, promise: nil) - } + func shutdown() { + self.channel.triggerUserOutboundEvent(HTTPConnectionEvent.shutdownRequested, promise: nil) + } - func close(promise: EventLoopPromise?) { - self.channel.close(mode: .all, promise: promise) - } + func close(promise: EventLoopPromise?) { + self.channel.close(mode: .all, promise: promise) + } - func close() -> EventLoopFuture { - let promise = self.channel.eventLoop.makePromise(of: Void.self) - self.close(promise: promise) - return promise.futureResult + func close() -> EventLoopFuture { + let promise = self.eventLoop.makePromise(of: Void.self) + self.close(promise: promise) + return promise.futureResult + } } func taskCompleted() { - self.delegate.http1ConnectionReleased(self) + self.delegate.http1ConnectionReleased(self.id) } private func execute0(request: HTTPExecutableRequest) { @@ -100,7 +113,7 @@ final class HTTP1Connection { return request.fail(ChannelError.ioOnClosedChannel) } - self.channel.write(request, promise: nil) + self.channel.pipeline.syncOperations.write(NIOAny(request), promise: nil) } private func start(decompression: HTTPClient.Decompression, logger: Logger) throws { @@ -111,9 +124,9 @@ final class HTTP1Connection { } self.state = .active - self.channel.closeFuture.whenComplete { _ in + self.channel.closeFuture.assumeIsolated().whenComplete { _ in self.state = .closed - self.delegate.http1ConnectionClosed(self) + self.delegate.http1ConnectionClosed(self.id) } do { @@ -150,3 +163,6 @@ final class HTTP1Connection { } } } + +@available(*, unavailable) +extension HTTP1Connection: Sendable {} diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift index aee0736ff..2cde1df3f 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift @@ -359,7 +359,7 @@ struct HTTP1ConnectionStateMachine { mutating func idleWriteTimeoutTriggered() -> Action { guard case .inRequest(var requestStateMachine, let close) = self.state else { - preconditionFailure("Invalid state: \(self.state)") + return .wait } return self.avoidingStateMachineCoW { state -> Action in diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift index 5e105c0d8..7c0197cdf 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift @@ -137,7 +137,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { self.runTimeoutAction(timeoutAction, context: context) } - request.willExecuteRequest(self) + request.willExecuteRequest(self.requestExecutor) let action = self.state.startRequest( head: request.requestHead, @@ -240,6 +240,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { self.request!.fail(error) self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) + self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) // No matter the error reason, we must always make sure the h2 stream is closed. Only // once the h2 stream is closed, it is released from the h2 multiplexer. The // HTTPRequestStateMachine may signal finalAction: .none in the error case (as this is @@ -252,6 +253,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { self.request!.succeedRequest(finalParts) self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) + self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) self.runSuccessfulFinalAction(finalAction, context: context) case .failSendBodyPart(let error, let writePromise), .failSendStreamFinished(let error, let writePromise): @@ -311,7 +313,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { assert(self.idleReadTimeoutTimer == nil, "Expected there is no timeout timer so far.") let timerID = self.currentIdleReadTimeoutTimerID - self.idleReadTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + self.idleReadTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { guard self.currentIdleReadTimeoutTimerID == timerID else { return } let action = self.state.idleReadTimeoutTriggered() self.run(action, context: context) @@ -324,7 +326,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { self.currentIdleReadTimeoutTimerID &+= 1 let timerID = self.currentIdleReadTimeoutTimerID - self.idleReadTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + self.idleReadTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { guard self.currentIdleReadTimeoutTimerID == timerID else { return } let action = self.state.idleReadTimeoutTriggered() self.run(action, context: context) @@ -347,7 +349,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { assert(self.idleWriteTimeoutTimer == nil, "Expected there is no timeout timer so far.") let timerID = self.currentIdleWriteTimeoutTimerID - self.idleWriteTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + self.idleWriteTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { guard self.currentIdleWriteTimeoutTimerID == timerID else { return } let action = self.state.idleWriteTimeoutTriggered() self.run(action, context: context) @@ -359,7 +361,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { self.currentIdleWriteTimeoutTimerID &+= 1 let timerID = self.currentIdleWriteTimeoutTimerID - self.idleWriteTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + self.idleWriteTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { guard self.currentIdleWriteTimeoutTimerID == timerID else { return } let action = self.state.idleWriteTimeoutTriggered() self.run(action, context: context) @@ -435,43 +437,39 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { @available(*, unavailable) extension HTTP2ClientRequestHandler: Sendable {} -extension HTTP2ClientRequestHandler: HTTPRequestExecutor { - func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { - if self.eventLoop.inEventLoop { - self.writeRequestBodyPart0(data, request: request, promise: promise) - } else { - self.eventLoop.execute { - self.writeRequestBodyPart0(data, request: request, promise: promise) +extension HTTP2ClientRequestHandler { + var requestExecutor: RequestExecutor { + RequestExecutor(self) + } + + struct RequestExecutor: HTTPRequestExecutor, Sendable { + private let loopBound: NIOLoopBound + + init(_ handler: HTTP2ClientRequestHandler) { + self.loopBound = NIOLoopBound(handler, eventLoop: handler.eventLoop) + } + + func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { + self.loopBound.execute { + $0.writeRequestBodyPart0(data, request: request, promise: promise) } } - } - func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { - if self.eventLoop.inEventLoop { - self.finishRequestBodyStream0(request, promise: promise) - } else { - self.eventLoop.execute { - self.finishRequestBodyStream0(request, promise: promise) + func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { + self.loopBound.execute { + $0.finishRequestBodyStream0(request, promise: promise) } } - } - func demandResponseBodyStream(_ request: HTTPExecutableRequest) { - if self.eventLoop.inEventLoop { - self.demandResponseBodyStream0(request) - } else { - self.eventLoop.execute { - self.demandResponseBodyStream0(request) + func demandResponseBodyStream(_ request: HTTPExecutableRequest) { + self.loopBound.execute { + $0.demandResponseBodyStream0(request) } } - } - func cancelRequest(_ request: HTTPExecutableRequest) { - if self.eventLoop.inEventLoop { - self.cancelRequest0(request) - } else { - self.eventLoop.execute { - self.cancelRequest0(request) + func cancelRequest(_ request: HTTPExecutableRequest) { + self.loopBound.execute { + $0.cancelRequest0(request) } } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift index 5e4ae6e01..1c24554e2 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift @@ -17,11 +17,11 @@ import NIOCore import NIOHTTP2 import NIOHTTPCompression -protocol HTTP2ConnectionDelegate { - func http2Connection(_: HTTP2Connection, newMaxStreamSetting: Int) - func http2ConnectionStreamClosed(_: HTTP2Connection, availableStreams: Int) - func http2ConnectionGoAwayReceived(_: HTTP2Connection) - func http2ConnectionClosed(_: HTTP2Connection) +protocol HTTP2ConnectionDelegate: Sendable { + func http2Connection(_: HTTPConnectionPool.Connection.ID, newMaxStreamSetting: Int) + func http2ConnectionStreamClosed(_: HTTPConnectionPool.Connection.ID, availableStreams: Int) + func http2ConnectionGoAwayReceived(_: HTTPConnectionPool.Connection.ID) + func http2ConnectionClosed(_: HTTPConnectionPool.Connection.ID) } struct HTTP2PushNotSupportedError: Error {} @@ -35,6 +35,9 @@ final class HTTP2Connection { let multiplexer: HTTP2StreamMultiplexer let logger: Logger + /// A method with access to the stream channel that is called when creating the stream. + let streamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? + /// the connection pool that created the connection let delegate: HTTP2ConnectionDelegate @@ -95,7 +98,8 @@ final class HTTP2Connection { decompression: HTTPClient.Decompression, maximumConnectionUses: Int?, delegate: HTTP2ConnectionDelegate, - logger: Logger + logger: Logger, + streamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil ) { self.channel = channel self.id = connectionID @@ -114,6 +118,7 @@ final class HTTP2Connection { ) self.delegate = delegate self.state = .initialized + self.streamChannelDebugInitializer = streamChannelDebugInitializer } deinit { @@ -128,49 +133,72 @@ final class HTTP2Connection { delegate: HTTP2ConnectionDelegate, decompression: HTTPClient.Decompression, maximumConnectionUses: Int?, - logger: Logger - ) -> EventLoopFuture<(HTTP2Connection, Int)> { + logger: Logger, + streamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil + ) -> EventLoopFuture<(HTTP2Connection, Int)>.Isolated { let connection = HTTP2Connection( channel: channel, connectionID: connectionID, decompression: decompression, maximumConnectionUses: maximumConnectionUses, delegate: delegate, - logger: logger + logger: logger, + streamChannelDebugInitializer: streamChannelDebugInitializer ) - return connection._start0().map { maxStreams in (connection, maxStreams) } + + return connection._start0().assumeIsolated().map { maxStreams in + (connection, maxStreams) + } } - func executeRequest(_ request: HTTPExecutableRequest) { - if self.channel.eventLoop.inEventLoop { - self.executeRequest0(request) - } else { - self.channel.eventLoop.execute { - self.executeRequest0(request) + var sendableView: SendableView { + SendableView(self) + } + + struct SendableView: Sendable { + private let connection: NIOLoopBound + let id: HTTPConnectionPool.Connection.ID + let channel: Channel + + var eventLoop: EventLoop { + self.connection.eventLoop + } + + var closeFuture: EventLoopFuture { + self.channel.closeFuture + } + + func __forTesting_getStreamChannels() -> [Channel] { + self.connection.value.__forTesting_getStreamChannels() + } + + init(_ connection: HTTP2Connection) { + self.connection = NIOLoopBound(connection, eventLoop: connection.channel.eventLoop) + self.id = connection.id + self.channel = connection.channel + } + + func executeRequest(_ request: HTTPExecutableRequest) { + self.connection.execute { + $0.executeRequest0(request) } } - } - /// shuts down the connection by cancelling all running tasks and closing the connection once - /// all child streams/channels are closed. - func shutdown() { - if self.channel.eventLoop.inEventLoop { - self.shutdown0() - } else { - self.channel.eventLoop.execute { - self.shutdown0() + func shutdown() { + self.connection.execute { + $0.shutdown0() } } - } - func close(promise: EventLoopPromise?) { - self.channel.close(mode: .all, promise: promise) - } + func close(promise: EventLoopPromise?) { + self.channel.close(mode: .all, promise: promise) + } - func close() -> EventLoopFuture { - let promise = self.channel.eventLoop.makePromise(of: Void.self) - self.close(promise: promise) - return promise.futureResult + func close() -> EventLoopFuture { + let promise = self.eventLoop.makePromise(of: Void.self) + self.close(promise: promise) + return promise.futureResult + } } func _start0() -> EventLoopFuture { @@ -179,7 +207,7 @@ final class HTTP2Connection { let readyToAcceptConnectionsPromise = self.channel.eventLoop.makePromise(of: Int.self) self.state = .starting(readyToAcceptConnectionsPromise) - self.channel.closeFuture.whenComplete { _ in + self.channel.closeFuture.assumeIsolated().whenComplete { _ in switch self.state { case .initialized, .closed: preconditionFailure("invalid state \(self.state)") @@ -188,7 +216,7 @@ final class HTTP2Connection { readyToAcceptConnectionsPromise.fail(HTTPClientError.remoteConnectionClosed) case .active, .closing: self.state = .closed - self.delegate.http2ConnectionClosed(self) + self.delegate.http2ConnectionClosed(self.id) } } @@ -227,13 +255,18 @@ final class HTTP2Connection { case .active: let createStreamChannelPromise = self.channel.eventLoop.makePromise(of: Channel.self) - self.multiplexer.createStreamChannel(promise: createStreamChannelPromise) { - channel -> EventLoopFuture in + let loopBoundSelf = NIOLoopBound(self, eventLoop: self.channel.eventLoop) + + self.multiplexer.createStreamChannel( + promise: createStreamChannelPromise + ) { [streamChannelDebugInitializer] channel -> EventLoopFuture in + let connection = loopBoundSelf.value + do { // the connection may have been asked to shutdown while we created the child. in // this // channel. - guard case .active = self.state else { + guard case .active = connection.state else { throw HTTPClientError.cancelled } @@ -242,7 +275,7 @@ final class HTTP2Connection { let translate = HTTP2FramePayloadToHTTP1ClientCodec(httpProtocol: .https) try channel.pipeline.syncOperations.addHandler(translate) - if case .enabled(let limit) = self.decompression { + if case .enabled(let limit) = connection.decompression { let decompressHandler = NIOHTTPResponseDecompressor(limit: limit) try channel.pipeline.syncOperations.addHandler(decompressHandler) } @@ -254,13 +287,19 @@ final class HTTP2Connection { // request to it. In case of an error, we are sure that the channel was added // before. let box = ChannelBox(channel) - self.openStreams.insert(box) - channel.closeFuture.whenComplete { _ in - self.openStreams.remove(box) + connection.openStreams.insert(box) + channel.closeFuture.assumeIsolated().whenComplete { _ in + connection.openStreams.remove(box) } - channel.write(request, promise: nil) - return channel.eventLoop.makeSucceededVoidFuture() + if let streamChannelDebugInitializer = streamChannelDebugInitializer { + return streamChannelDebugInitializer(channel).map { _ in + channel.write(request, promise: nil) + } + } else { + channel.pipeline.syncOperations.write(NIOAny(request), promise: nil) + return channel.eventLoop.makeSucceededVoidFuture() + } } catch { return channel.eventLoop.makeFailedFuture(error) } @@ -322,7 +361,7 @@ extension HTTP2Connection: HTTP2IdleHandlerDelegate { case .active: self.state = .active(maxStreams: maxStreams) - self.delegate.http2Connection(self, newMaxStreamSetting: maxStreams) + self.delegate.http2Connection(self.id, newMaxStreamSetting: maxStreams) case .closing, .closed: // ignore. we only wait for all connections to be closed anyway. @@ -343,7 +382,7 @@ extension HTTP2Connection: HTTP2IdleHandlerDelegate { case .active: self.state = .closing - self.delegate.http2ConnectionGoAwayReceived(self) + self.delegate.http2ConnectionGoAwayReceived(self.id) case .closing, .closed: // we are already closing. Nothing new @@ -354,6 +393,9 @@ extension HTTP2Connection: HTTP2IdleHandlerDelegate { func http2StreamClosed(availableStreams: Int) { self.channel.eventLoop.assertInEventLoop() - self.delegate.http2ConnectionStreamClosed(self, availableStreams: availableStreams) + self.delegate.http2ConnectionStreamClosed(self.id, availableStreams: availableStreams) } } + +@available(*, unavailable) +extension HTTP2Connection: Sendable {} diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift index 0aad0c8dd..c896791cf 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift @@ -22,6 +22,7 @@ import NIOSSL import NIOTLS #if canImport(Network) +import Network import NIOTransportServices #endif @@ -47,9 +48,9 @@ extension HTTPConnectionPool { } } -protocol HTTPConnectionRequester { - func http1ConnectionCreated(_: HTTP1Connection) - func http2ConnectionCreated(_: HTTP2Connection, maximumStreams: Int) +protocol HTTPConnectionRequester: Sendable { + func http1ConnectionCreated(_: HTTP1Connection.SendableView) + func http2ConnectionCreated(_: HTTP2Connection.SendableView, maximumStreams: Int) func failedToCreateHTTPConnection(_: HTTPConnectionPool.Connection.ID, error: Error) func waitingForConnectivity(_: HTTPConnectionPool.Connection.ID, error: Error) } @@ -73,7 +74,7 @@ extension HTTPConnectionPool.ConnectionFactory { deadline: deadline, eventLoop: eventLoop, logger: logger - ).whenComplete { result in + ).whenComplete { [logger] result in switch result { case .success(.http1_1(let channel)): do { @@ -84,7 +85,21 @@ extension HTTPConnectionPool.ConnectionFactory { decompression: self.clientConfiguration.decompression, logger: logger ) - requester.http1ConnectionCreated(connection) + + if let connectionDebugInitializer = self.clientConfiguration.http1_1ConnectionDebugInitializer { + connectionDebugInitializer(channel).hop( + to: eventLoop + ).assumeIsolated().whenComplete { debugInitializerResult in + switch debugInitializerResult { + case .success: + requester.http1ConnectionCreated(connection.sendableView) + case .failure(let error): + requester.failedToCreateHTTPConnection(connectionID, error: error) + } + } + } else { + requester.http1ConnectionCreated(connection.sendableView) + } } catch { requester.failedToCreateHTTPConnection(connectionID, error: error) } @@ -95,11 +110,34 @@ extension HTTPConnectionPool.ConnectionFactory { delegate: http2ConnectionDelegate, decompression: self.clientConfiguration.decompression, maximumConnectionUses: self.clientConfiguration.maximumUsesPerConnection, - logger: logger + logger: logger, + streamChannelDebugInitializer: + self.clientConfiguration.http2StreamChannelDebugInitializer ).whenComplete { result in switch result { case .success((let connection, let maximumStreams)): - requester.http2ConnectionCreated(connection, maximumStreams: maximumStreams) + if let connectionDebugInitializer = self.clientConfiguration.http2ConnectionDebugInitializer { + connectionDebugInitializer(channel).hop(to: eventLoop).assumeIsolated().whenComplete { + debugInitializerResult in + switch debugInitializerResult { + case .success: + requester.http2ConnectionCreated( + connection.sendableView, + maximumStreams: maximumStreams + ) + case .failure(let error): + requester.failedToCreateHTTPConnection( + connectionID, + error: error + ) + } + } + } else { + requester.http2ConnectionCreated( + connection.sendableView, + maximumStreams: maximumStreams + ) + } case .failure(let error): requester.failedToCreateHTTPConnection(connectionID, error: error) } @@ -249,15 +287,15 @@ extension HTTPConnectionPool.ConnectionFactory { // The proxyEstablishedFuture is set as soon as the HTTP1ProxyConnectHandler is in a // pipeline. It is created in HTTP1ProxyConnectHandler's handlerAdded method. - return proxyHandler.proxyEstablishedFuture!.flatMap { - channel.pipeline.removeHandler(proxyHandler).flatMap { - channel.pipeline.removeHandler(decoder).flatMap { - channel.pipeline.removeHandler(encoder) - } - } + return proxyHandler.proxyEstablishedFuture!.assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(proxyHandler).assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(decoder).assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(encoder) + }.nonisolated() + }.nonisolated() }.flatMap { self.setupTLSInProxyConnectionIfNeeded(channel, deadline: deadline, logger: logger) - } + }.nonisolated() } } @@ -291,13 +329,13 @@ extension HTTPConnectionPool.ConnectionFactory { // The socksEstablishedFuture is set as soon as the SOCKSEventsHandler is in a // pipeline. It is created in SOCKSEventsHandler's handlerAdded method. - return socksEventHandler.socksEstablishedFuture!.flatMap { - channel.pipeline.removeHandler(socksEventHandler).flatMap { - channel.pipeline.removeHandler(socksConnectHandler) - } + return socksEventHandler.socksEstablishedFuture!.assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(socksEventHandler).assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(socksConnectHandler) + }.nonisolated() }.flatMap { self.setupTLSInProxyConnectionIfNeeded(channel, deadline: deadline, logger: logger) - } + }.nonisolated() } } @@ -323,7 +361,6 @@ extension HTTPConnectionPool.ConnectionFactory { case .http1Only: tlsConfig.applicationProtocols = ["http/1.1"] } - let tlsEventHandler = TLSEventsHandler(deadline: deadline) let sslServerHostname = self.key.serverNameIndicator let sslContextFuture = self.sslContextCache.sslContext( @@ -339,6 +376,7 @@ extension HTTPConnectionPool.ConnectionFactory { serverHostname: sslServerHostname ) try channel.pipeline.syncOperations.addHandler(sslHandler) + let tlsEventHandler = TLSEventsHandler(deadline: deadline) try channel.pipeline.syncOperations.addHandler(tlsEventHandler) // The tlsEstablishedFuture is set as soon as the TLSEventsHandler is in a @@ -348,8 +386,14 @@ extension HTTPConnectionPool.ConnectionFactory { return channel.eventLoop.makeFailedFuture(error) } }.flatMap { negotiated -> EventLoopFuture in - channel.pipeline.removeHandler(tlsEventHandler).flatMapThrowing { - try self.matchALPNToHTTPVersion(negotiated, channel: channel) + do { + let sync = channel.pipeline.syncOperations + let context = try sync.context(handlerType: TLSEventsHandler.self) + return sync.removeHandler(context: context).flatMapThrowing { + try self.matchALPNToHTTPVersion(negotiated, channel: channel) + } + } catch { + return channel.eventLoop.makeFailedFuture(error) } } } @@ -426,9 +470,9 @@ extension HTTPConnectionPool.ConnectionFactory { // The tlsEstablishedFuture is set as soon as the TLSEventsHandler is in a // pipeline. It is created in TLSEventsHandler's handlerAdded method. - return tlsEventHandler.tlsEstablishedFuture!.flatMap { negotiated in - channel.pipeline.removeHandler(tlsEventHandler).map { (channel, negotiated) } - } + return tlsEventHandler.tlsEstablishedFuture!.assumeIsolated().flatMap { negotiated in + channel.pipeline.syncOperations.removeHandler(tlsEventHandler).map { (channel, negotiated) } + }.nonisolated() } catch { assert( channel.isActive == false, @@ -468,9 +512,7 @@ extension HTTPConnectionPool.ConnectionFactory { } #if canImport(Network) - if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), - let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) - { + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), eventLoop is QoSEventLoop { // create NIOClientTCPBootstrap with NIOTS TLS provider let bootstrapFuture = tlsConfig.getNWProtocolTLSOptions( on: eventLoop, @@ -478,7 +520,7 @@ extension HTTPConnectionPool.ConnectionFactory { ).map { options -> NIOClientTCPBootstrapProtocol in - tsBootstrap + NIOTSConnectionBootstrap(group: eventLoop) // validated above .channelOption( NIOTSChannelOptions.waitForActivity, value: self.clientConfiguration.networkFrameworkWaitForConnectivity @@ -515,29 +557,29 @@ extension HTTPConnectionPool.ConnectionFactory { logger: logger ) - let bootstrap = ClientBootstrap(group: eventLoop) - .connectTimeout(deadline - NIODeadline.now()) - .enableMPTCP(clientConfiguration.enableMultipath) - .channelInitializer { channel in - sslContextFuture.flatMap { sslContext -> EventLoopFuture in - do { - let sync = channel.pipeline.syncOperations - let sslHandler = try NIOSSLClientHandler( - context: sslContext, - serverHostname: self.key.serverNameIndicator - ) - let tlsEventHandler = TLSEventsHandler(deadline: deadline) + return eventLoop.submit { + ClientBootstrap(group: eventLoop) + .connectTimeout(deadline - NIODeadline.now()) + .enableMPTCP(clientConfiguration.enableMultipath) + .channelInitializer { channel in + sslContextFuture.flatMap { sslContext -> EventLoopFuture in + do { + let sync = channel.pipeline.syncOperations + let sslHandler = try NIOSSLClientHandler( + context: sslContext, + serverHostname: self.key.serverNameIndicator + ) + let tlsEventHandler = TLSEventsHandler(deadline: deadline) - try sync.addHandler(sslHandler) - try sync.addHandler(tlsEventHandler) - return channel.eventLoop.makeSucceededVoidFuture() - } catch { - return channel.eventLoop.makeFailedFuture(error) + try sync.addHandler(sslHandler) + try sync.addHandler(tlsEventHandler) + return channel.eventLoop.makeSucceededVoidFuture() + } catch { + return channel.eventLoop.makeFailedFuture(error) + } } } - } - - return eventLoop.makeSucceededFuture(bootstrap) + } } private func matchALPNToHTTPVersion(_ negotiated: String?, channel: Channel) throws -> NegotiatedProtocol { diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift index eebe4d029..251224ac0 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift @@ -324,7 +324,9 @@ final class HTTPConnectionPool: connection.executeRequest(request.req) case .executeRequests(let requests, let connection): - for request in requests { connection.executeRequest(request.req) } + for request in requests { + connection.executeRequest(request.req) + } case .failRequest(let request, let error): request.req.fail(error) @@ -459,7 +461,7 @@ final class HTTPConnectionPool: // MARK: - Protocol methods - extension HTTPConnectionPool: HTTPConnectionRequester { - func http1ConnectionCreated(_ connection: HTTP1Connection) { + func http1ConnectionCreated(_ connection: HTTP1Connection.SendableView) { self.logger.trace( "successfully created connection", metadata: [ @@ -472,7 +474,7 @@ extension HTTPConnectionPool: HTTPConnectionRequester { } } - func http2ConnectionCreated(_ connection: HTTP2Connection, maximumStreams: Int) { + func http2ConnectionCreated(_ connection: HTTP2Connection.SendableView, maximumStreams: Int) { self.logger.trace( "successfully created connection", metadata: [ @@ -514,84 +516,84 @@ extension HTTPConnectionPool: HTTPConnectionRequester { } extension HTTPConnectionPool: HTTP1ConnectionDelegate { - func http1ConnectionClosed(_ connection: HTTP1Connection) { + func http1ConnectionClosed(_ id: HTTPConnectionPool.Connection.ID) { self.logger.debug( "connection closed", metadata: [ - "ahc-connection-id": "\(connection.id)", + "ahc-connection-id": "\(id)", "ahc-http-version": "http/1.1", ] ) self.modifyStateAndRunActions { - $0.http1ConnectionClosed(connection.id) + $0.http1ConnectionClosed(id) } } - func http1ConnectionReleased(_ connection: HTTP1Connection) { + func http1ConnectionReleased(_ id: HTTPConnectionPool.Connection.ID) { self.logger.trace( "releasing connection", metadata: [ - "ahc-connection-id": "\(connection.id)", + "ahc-connection-id": "\(id)", "ahc-http-version": "http/1.1", ] ) self.modifyStateAndRunActions { - $0.http1ConnectionReleased(connection.id) + $0.http1ConnectionReleased(id) } } } extension HTTPConnectionPool: HTTP2ConnectionDelegate { - func http2Connection(_ connection: HTTP2Connection, newMaxStreamSetting: Int) { + func http2Connection(_ id: HTTPConnectionPool.Connection.ID, newMaxStreamSetting: Int) { self.logger.debug( "new max stream setting", metadata: [ - "ahc-connection-id": "\(connection.id)", + "ahc-connection-id": "\(id)", "ahc-http-version": "http/2", "ahc-max-streams": "\(newMaxStreamSetting)", ] ) self.modifyStateAndRunActions { - $0.newHTTP2MaxConcurrentStreamsReceived(connection.id, newMaxStreams: newMaxStreamSetting) + $0.newHTTP2MaxConcurrentStreamsReceived(id, newMaxStreams: newMaxStreamSetting) } } - func http2ConnectionGoAwayReceived(_ connection: HTTP2Connection) { + func http2ConnectionGoAwayReceived(_ id: HTTPConnectionPool.Connection.ID) { self.logger.debug( "connection go away received", metadata: [ - "ahc-connection-id": "\(connection.id)", + "ahc-connection-id": "\(id)", "ahc-http-version": "http/2", ] ) self.modifyStateAndRunActions { - $0.http2ConnectionGoAwayReceived(connection.id) + $0.http2ConnectionGoAwayReceived(id) } } - func http2ConnectionClosed(_ connection: HTTP2Connection) { + func http2ConnectionClosed(_ id: HTTPConnectionPool.Connection.ID) { self.logger.debug( "connection closed", metadata: [ - "ahc-connection-id": "\(connection.id)", + "ahc-connection-id": "\(id)", "ahc-http-version": "http/2", ] ) self.modifyStateAndRunActions { - $0.http2ConnectionClosed(connection.id) + $0.http2ConnectionClosed(id) } } - func http2ConnectionStreamClosed(_ connection: HTTP2Connection, availableStreams: Int) { + func http2ConnectionStreamClosed(_ id: HTTPConnectionPool.Connection.ID, availableStreams: Int) { self.logger.trace( "stream closed", metadata: [ - "ahc-connection-id": "\(connection.id)", + "ahc-connection-id": "\(id)", "ahc-http-version": "http/2", ] ) self.modifyStateAndRunActions { - $0.http2ConnectionStreamClosed(connection.id) + $0.http2ConnectionStreamClosed(id) } } } @@ -610,18 +612,18 @@ extension HTTPConnectionPool { typealias ID = Int private enum Reference { - case http1_1(HTTP1Connection) - case http2(HTTP2Connection) + case http1_1(HTTP1Connection.SendableView) + case http2(HTTP2Connection.SendableView) case __testOnly_connection(ID, EventLoop) } private let _ref: Reference - fileprivate static func http1_1(_ conn: HTTP1Connection) -> Self { + fileprivate static func http1_1(_ conn: HTTP1Connection.SendableView) -> Self { Connection(_ref: .http1_1(conn)) } - fileprivate static func http2(_ conn: HTTP2Connection) -> Self { + fileprivate static func http2(_ conn: HTTP2Connection.SendableView) -> Self { Connection(_ref: .http2(conn)) } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift index e8c07e50f..bce55eb5b 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift @@ -176,7 +176,7 @@ protocol HTTPSchedulableRequest: HTTPExecutableRequest { /// A handle to the request executor. /// /// This protocol is implemented by the `HTTP1ClientChannelHandler`. -protocol HTTPRequestExecutor { +protocol HTTPRequestExecutor: Sendable { /// Writes a body part into the channel pipeline /// /// This method may be **called on any thread**. The executor needs to ensure thread safety. @@ -201,7 +201,7 @@ protocol HTTPRequestExecutor { func cancelRequest(_ task: HTTPExecutableRequest) } -protocol HTTPExecutableRequest: AnyObject { +protocol HTTPExecutableRequest: AnyObject, Sendable { /// The request's logger var logger: Logger { get } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+Backoff.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+Backoff.swift index 86a54273d..71d8f15f1 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+Backoff.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+Backoff.swift @@ -18,6 +18,8 @@ import NIOCore import func Darwin.pow #elseif canImport(Musl) import func Musl.pow +#elseif canImport(Android) +import func Android.pow #else import func Glibc.pow #endif diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift index 6dfd4223e..0cc02cf0f 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift @@ -26,7 +26,9 @@ extension HTTPConnectionPool { self.connection = connection } - static let none = Action(request: .none, connection: .none) + static var none: Action { + Action(request: .none, connection: .none) + } } enum ConnectionAction { @@ -397,7 +399,9 @@ extension HTTPConnectionPool.StateMachine { } struct EstablishedAction { - static let none: Self = .init(request: .none, connection: .none) + static var none: Self { + Self(request: .none, connection: .none) + } let request: HTTPConnectionPool.StateMachine.RequestAction let connection: EstablishedConnectionAction } diff --git a/Sources/AsyncHTTPClient/FileDownloadDelegate.swift b/Sources/AsyncHTTPClient/FileDownloadDelegate.swift index 1f869506a..33a4d3cb2 100644 --- a/Sources/AsyncHTTPClient/FileDownloadDelegate.swift +++ b/Sources/AsyncHTTPClient/FileDownloadDelegate.swift @@ -12,30 +12,68 @@ // //===----------------------------------------------------------------------===// +import NIOConcurrencyHelpers import NIOCore import NIOHTTP1 import NIOPosix +import struct Foundation.URL + /// Handles a streaming download to a given file path, allowing headers and progress to be reported. public final class FileDownloadDelegate: HTTPClientResponseDelegate { /// The response type for this delegate: the total count of bytes as reported by the response - /// "Content-Length" header (if available) and the count of bytes downloaded. + /// "Content-Length" header (if available), the count of bytes downloaded, the + /// response head, and a history of requests and responses. public struct Progress: Sendable { public var totalBytes: Int? public var receivedBytes: Int + + /// The history of all requests and responses in redirect order. + public var history: [HTTPClient.RequestResponse] = [] + + /// The target URL (after redirects) of the response. + public var url: URL? { + self.history.last?.request.url + } + + public var head: HTTPResponseHead { + get { + assert(self._head != nil) + return self._head! + } + set { + self._head = newValue + } + } + + fileprivate var _head: HTTPResponseHead? = nil + + internal init(totalBytes: Int? = nil, receivedBytes: Int) { + self.totalBytes = totalBytes + self.receivedBytes = receivedBytes + } } - private var progress = Progress(totalBytes: nil, receivedBytes: 0) + private struct State { + var progress = Progress( + totalBytes: nil, + receivedBytes: 0 + ) + var fileIOThreadPool: NIOThreadPool? + var fileHandleFuture: EventLoopFuture? + var writeFuture: EventLoopFuture? + } + private let state: NIOLockedValueBox + + var _fileIOThreadPool: NIOThreadPool? { + self.state.withLockedValue { $0.fileIOThreadPool } + } public typealias Response = Progress private let filePath: String - private(set) var fileIOThreadPool: NIOThreadPool? - private let reportHead: ((HTTPClient.Task, HTTPResponseHead) -> Void)? - private let reportProgress: ((HTTPClient.Task, Progress) -> Void)? - - private var fileHandleFuture: EventLoopFuture? - private var writeFuture: EventLoopFuture? + private let reportHead: (@Sendable (HTTPClient.Task, HTTPResponseHead) -> Void)? + private let reportProgress: (@Sendable (HTTPClient.Task, Progress) -> Void)? /// Initializes a new file download delegate. /// @@ -47,20 +85,14 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { /// the total byte count and download byte count passed to it as arguments. The callbacks /// will be invoked in the same threading context that the delegate itself is invoked, /// as controlled by `EventLoopPreference`. + @preconcurrency public init( path: String, pool: NIOThreadPool? = nil, - reportHead: ((HTTPClient.Task, HTTPResponseHead) -> Void)? = nil, - reportProgress: ((HTTPClient.Task, Progress) -> Void)? = nil + reportHead: (@Sendable (HTTPClient.Task, HTTPResponseHead) -> Void)? = nil, + reportProgress: (@Sendable (HTTPClient.Task, Progress) -> Void)? = nil ) throws { - if let pool = pool { - self.fileIOThreadPool = pool - } else { - // we should use the shared thread pool from the HTTPClient which - // we will get from the `HTTPClient.Task` - self.fileIOThreadPool = nil - } - + self.state = NIOLockedValueBox(State(fileIOThreadPool: pool)) self.filePath = path self.reportHead = reportHead @@ -77,22 +109,23 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { /// the total byte count and download byte count passed to it as arguments. The callbacks /// will be invoked in the same threading context that the delegate itself is invoked, /// as controlled by `EventLoopPreference`. + @preconcurrency public convenience init( path: String, pool: NIOThreadPool, - reportHead: ((HTTPResponseHead) -> Void)? = nil, - reportProgress: ((Progress) -> Void)? = nil + reportHead: (@Sendable (HTTPResponseHead) -> Void)? = nil, + reportProgress: (@Sendable (Progress) -> Void)? = nil ) throws { try self.init( path: path, pool: .some(pool), reportHead: reportHead.map { reportHead in - { _, head in + { @Sendable _, head in reportHead(head) } }, reportProgress: reportProgress.map { reportProgress in - { _, head in + { @Sendable _, head in reportProgress(head) } } @@ -108,39 +141,50 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { /// the total byte count and download byte count passed to it as arguments. The callbacks /// will be invoked in the same threading context that the delegate itself is invoked, /// as controlled by `EventLoopPreference`. + @preconcurrency public convenience init( path: String, - reportHead: ((HTTPResponseHead) -> Void)? = nil, - reportProgress: ((Progress) -> Void)? = nil + reportHead: (@Sendable (HTTPResponseHead) -> Void)? = nil, + reportProgress: (@Sendable (Progress) -> Void)? = nil ) throws { try self.init( path: path, pool: nil, reportHead: reportHead.map { reportHead in - { _, head in + { @Sendable _, head in reportHead(head) } }, reportProgress: reportProgress.map { reportProgress in - { _, head in + { @Sendable _, head in reportProgress(head) } } ) } + public func didVisitURL(task: HTTPClient.Task, _ request: HTTPClient.Request, _ head: HTTPResponseHead) { + self.state.withLockedValue { + $0.progress.history.append(.init(request: request, responseHead: head)) + } + } + public func didReceiveHead( task: HTTPClient.Task, _ head: HTTPResponseHead ) -> EventLoopFuture { - self.reportHead?(task, head) + self.state.withLockedValue { + $0.progress._head = head - if let totalBytesString = head.headers.first(name: "Content-Length"), - let totalBytes = Int(totalBytesString) - { - self.progress.totalBytes = totalBytes + if let totalBytesString = head.headers.first(name: "Content-Length"), + let totalBytes = Int(totalBytesString) + { + $0.progress.totalBytes = totalBytes + } } + self.reportHead?(task, head) + return task.eventLoop.makeSucceededFuture(()) } @@ -148,53 +192,90 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { task: HTTPClient.Task, _ buffer: ByteBuffer ) -> EventLoopFuture { - let threadPool: NIOThreadPool = { - guard let pool = self.fileIOThreadPool else { - let pool = task.fileIOThreadPool - self.fileIOThreadPool = pool + let (progress, io) = self.state.withLockedValue { state in + let threadPool: NIOThreadPool = { + guard let pool = state.fileIOThreadPool else { + let pool = task.fileIOThreadPool + state.fileIOThreadPool = pool + return pool + } return pool + }() + + let io = NonBlockingFileIO(threadPool: threadPool) + state.progress.receivedBytes += buffer.readableBytes + return (state.progress, io) + } + self.reportProgress?(task, progress) + + let writeFuture = self.state.withLockedValue { state in + let writeFuture: EventLoopFuture + if let fileHandleFuture = state.fileHandleFuture { + writeFuture = fileHandleFuture.flatMap { + io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) + } + } else { + let fileHandleFuture = io.openFile( + _deprecatedPath: self.filePath, + mode: .write, + flags: .allowFileCreation(), + eventLoop: task.eventLoop + ) + state.fileHandleFuture = fileHandleFuture + writeFuture = fileHandleFuture.flatMap { + io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) + } } - return pool - }() - let io = NonBlockingFileIO(threadPool: threadPool) - self.progress.receivedBytes += buffer.readableBytes - self.reportProgress?(task, self.progress) - - let writeFuture: EventLoopFuture - if let fileHandleFuture = self.fileHandleFuture { - writeFuture = fileHandleFuture.flatMap { - io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) - } - } else { - let fileHandleFuture = io.openFile( - path: self.filePath, - mode: .write, - flags: .allowFileCreation(), - eventLoop: task.eventLoop - ) - self.fileHandleFuture = fileHandleFuture - writeFuture = fileHandleFuture.flatMap { - io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) - } + + state.writeFuture = writeFuture + return writeFuture } - self.writeFuture = writeFuture return writeFuture } private func close(fileHandle: NIOFileHandle) { try! fileHandle.close() - self.fileHandleFuture = nil + self.state.withLockedValue { + $0.fileHandleFuture = nil + } } private func finalize() { - if let writeFuture = self.writeFuture { - writeFuture.whenComplete { _ in - self.fileHandleFuture?.whenSuccess(self.close(fileHandle:)) - self.writeFuture = nil + enum Finalize { + case writeFuture(EventLoopFuture) + case fileHandleFuture(EventLoopFuture) + case none + } + + let finalize: Finalize = self.state.withLockedValue { state in + if let writeFuture = state.writeFuture { + return .writeFuture(writeFuture) + } else if let fileHandleFuture = state.fileHandleFuture { + return .fileHandleFuture(fileHandleFuture) + } else { + return .none + } + } + + switch finalize { + case .writeFuture(let future): + future.whenComplete { _ in + let fileHandleFuture = self.state.withLockedValue { state in + let future = state.fileHandleFuture + state.fileHandleFuture = nil + state.writeFuture = nil + return future + } + + fileHandleFuture?.whenSuccess { + self.close(fileHandle: $0) + } } - } else { - self.fileHandleFuture?.whenSuccess(self.close(fileHandle:)) + case .fileHandleFuture(let future): + future.whenSuccess { self.close(fileHandle: $0) } + case .none: + () } } @@ -204,6 +285,6 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { public func didFinishRequest(task: HTTPClient.Task) throws -> Response { self.finalize() - return self.progress + return self.state.withLockedValue { $0.progress } } } diff --git a/Sources/AsyncHTTPClient/HTTPClient+HTTPCookie.swift b/Sources/AsyncHTTPClient/HTTPClient+HTTPCookie.swift index 847a99af2..759f6728a 100644 --- a/Sources/AsyncHTTPClient/HTTPClient+HTTPCookie.swift +++ b/Sources/AsyncHTTPClient/HTTPClient+HTTPCookie.swift @@ -26,6 +26,8 @@ import locale_h import Darwin #elseif canImport(Musl) import Musl +#elseif canImport(Android) +import Android #elseif canImport(Glibc) import Glibc #endif @@ -214,7 +216,7 @@ extension String.UTF8View.SubSequence { } } -private let posixLocale: UnsafeMutableRawPointer = { +nonisolated(unsafe) private let posixLocale: UnsafeMutableRawPointer = { // All POSIX systems must provide a "POSIX" locale, and its date/time formats are US English. // https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/V1_chap07.html#tag_07_03_05 let _posixLocale = newlocale(LC_TIME_MASK | LC_NUMERIC_MASK, "POSIX", nil)! diff --git a/Sources/AsyncHTTPClient/HTTPClient+StructuredConcurrency.swift b/Sources/AsyncHTTPClient/HTTPClient+StructuredConcurrency.swift new file mode 100644 index 000000000..f7d471f10 --- /dev/null +++ b/Sources/AsyncHTTPClient/HTTPClient+StructuredConcurrency.swift @@ -0,0 +1,72 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2025 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIO + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClient { + #if compiler(>=6.0) + /// Start & automatically shut down a new ``HTTPClient``. + /// + /// This method allows to start & automatically dispose of a ``HTTPClient`` following the principle of Structured Concurrency. + /// The ``HTTPClient`` is guaranteed to be shut down upon return, whether `body` throws or not. + /// + /// This may be particularly useful if you cannot use the shared singleton (``HTTPClient/shared``). + public static func withHTTPClient( + eventLoopGroup: any EventLoopGroup = HTTPClient.defaultEventLoopGroup, + configuration: Configuration = Configuration(), + backgroundActivityLogger: Logger? = nil, + isolation: isolated (any Actor)? = #isolation, + _ body: (HTTPClient) async throws -> Return + ) async throws -> Return { + let logger = (backgroundActivityLogger ?? HTTPClient.loggingDisabled) + let httpClient = HTTPClient( + eventLoopGroup: eventLoopGroup, + configuration: configuration, + backgroundActivityLogger: logger + ) + return try await asyncDo { + try await body(httpClient) + } finally: { _ in + try await httpClient.shutdown() + } + } + #else + /// Start & automatically shut down a new ``HTTPClient``. + /// + /// This method allows to start & automatically dispose of a ``HTTPClient`` following the principle of Structured Concurrency. + /// The ``HTTPClient`` is guaranteed to be shut down upon return, whether `body` throws or not. + /// + /// This may be particularly useful if you cannot use the shared singleton (``HTTPClient/shared``). + public static func withHTTPClient( + eventLoopGroup: any EventLoopGroup = HTTPClient.defaultEventLoopGroup, + configuration: Configuration = Configuration(), + backgroundActivityLogger: Logger? = nil, + _ body: (HTTPClient) async throws -> Return + ) async throws -> Return { + let logger = (backgroundActivityLogger ?? HTTPClient.loggingDisabled) + let httpClient = HTTPClient( + eventLoopGroup: eventLoopGroup, + configuration: configuration, + backgroundActivityLogger: logger + ) + return try await asyncDo { + try await body(httpClient) + } finally: { _ in + try await httpClient.shutdown() + } + } + #endif +} diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index f1655c7c5..e628c6073 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -222,23 +222,50 @@ public class HTTPClient { """ ) } - let errorStorage: NIOLockedValueBox = NIOLockedValueBox(nil) - let continuation = DispatchWorkItem {} - self.shutdown(requiresCleanClose: requiresCleanClose, queue: DispatchQueue(label: "async-http-client.shutdown")) - { error in - if let error = error { - errorStorage.withLockedValue { errorStorage in - errorStorage = error + + final class ShutdownError: @unchecked Sendable { + // @unchecked because error is protected by lock. + + // Stores whether the shutdown has happened or not. + private let lock: ConditionLock + private var error: Error? + + init() { + self.error = nil + self.lock = ConditionLock(value: false) + } + + func didShutdown(_ error: (any Error)?) { + self.lock.lock(whenValue: false) + defer { + self.lock.unlock(withValue: true) } + self.error = error } - continuation.perform() - } - continuation.wait() - try errorStorage.withLockedValue { errorStorage in - if let error = errorStorage { - throw error + + func blockUntilShutdown() -> (any Error)? { + self.lock.lock(whenValue: true) + defer { + self.lock.unlock(withValue: true) + } + return self.error } } + + let shutdownError = ShutdownError() + + self.shutdown( + requiresCleanClose: requiresCleanClose, + queue: DispatchQueue(label: "async-http-client.shutdown") + ) { error in + shutdownError.didShutdown(error) + } + + let error = shutdownError.blockUntilShutdown() + + if let error = error { + throw error + } } /// Shuts down the client and event loop gracefully. @@ -311,6 +338,7 @@ public class HTTPClient { } } + @Sendable private func makeOrGetFileIOThreadPool() -> NIOThreadPool { self.fileIOThreadPoolLock.withLock { guard let fileIOThreadPool = self.fileIOThreadPool else { @@ -756,20 +784,20 @@ public class HTTPClient { delegate: delegate ) - var deadlineSchedule: Scheduled? if let deadline = deadline { - deadlineSchedule = taskEL.scheduleTask(deadline: deadline) { + let deadlineSchedule = taskEL.scheduleTask(deadline: deadline) { requestBag.deadlineExceeded() } task.promise.futureResult.whenComplete { _ in - deadlineSchedule?.cancel() + deadlineSchedule.cancel() } } self.poolManager.executeRequest(requestBag) } catch { - task.fail(with: error, delegateType: Delegate.self) + delegate.didReceiveError(task: task, error) + task.failInternal(with: error) } return task @@ -847,6 +875,15 @@ public class HTTPClient { /// By default, don't use it public var enableMultipath: Bool + /// A method with access to the HTTP/1 connection channel that is called when creating the connection. + public var http1_1ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? + + /// A method with access to the HTTP/2 connection channel that is called when creating the connection. + public var http2ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? + + /// A method with access to the HTTP/2 stream channel that is called when creating the stream. + public var http2StreamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? + public init( tlsConfiguration: TLSConfiguration? = nil, redirectConfiguration: RedirectConfiguration? = nil, @@ -949,6 +986,32 @@ public class HTTPClient { decompression: decompression ) } + + public init( + tlsConfiguration: TLSConfiguration? = nil, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + connectionPool: ConnectionPool = ConnectionPool(), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled, + http1_1ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil, + http2ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil, + http2StreamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil + ) { + self.init( + tlsConfiguration: tlsConfiguration, + redirectConfiguration: redirectConfiguration, + timeout: timeout, + connectionPool: connectionPool, + proxy: proxy, + ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, + decompression: decompression + ) + self.http1_1ConnectionDebugInitializer = http1_1ConnectionDebugInitializer + self.http2ConnectionDebugInitializer = http2ConnectionDebugInitializer + self.http2StreamChannelDebugInitializer = http2StreamChannelDebugInitializer + } } /// Specifies how `EventLoopGroup` will be created and establishes lifecycle ownership. diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 0f061fbe6..8d92d8ef7 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -13,7 +13,6 @@ //===----------------------------------------------------------------------===// import Algorithms -import Foundation import Logging import NIOConcurrencyHelpers import NIOCore @@ -21,20 +20,27 @@ import NIOHTTP1 import NIOPosix import NIOSSL +#if compiler(>=6.0) +import Foundation +#else +@preconcurrency import Foundation +#endif + extension HTTPClient { /// A request body. - public struct Body { + public struct Body: Sendable { /// A streaming uploader. /// /// ``StreamWriter`` abstracts - public struct StreamWriter { - let closure: (IOData) -> EventLoopFuture + public struct StreamWriter: Sendable { + let closure: @Sendable (IOData) -> EventLoopFuture /// Create new ``HTTPClient/Body/StreamWriter`` /// /// - parameters: /// - closure: function that will be called to write actual bytes to the channel. - public init(closure: @escaping (IOData) -> EventLoopFuture) { + @preconcurrency + public init(closure: @escaping @Sendable (IOData) -> EventLoopFuture) { self.closure = closure } @@ -50,8 +56,8 @@ extension HTTPClient { func writeChunks( of bytes: Bytes, maxChunkSize: Int - ) -> EventLoopFuture where Bytes.Element == UInt8 { - // `StreamWriter` is has design issues, for example + ) -> EventLoopFuture where Bytes.Element == UInt8, Bytes: Sendable { + // `StreamWriter` has design issues, for example // - https://github.com/swift-server/async-http-client/issues/194 // - https://github.com/swift-server/async-http-client/issues/264 // - We're not told the EventLoop the task runs on and the user is free to return whatever EL they @@ -61,49 +67,52 @@ extension HTTPClient { typealias Iterator = EnumeratedSequence>.Iterator typealias Chunk = (offset: Int, element: ChunksOfCountCollection.Element) - func makeIteratorAndFirstChunk( - bytes: Bytes - ) -> ( - iterator: NIOLockedValueBox, - chunk: Chunk - )? { - var iterator = bytes.chunks(ofCount: maxChunkSize).enumerated().makeIterator() - guard let chunk = iterator.next() else { - return nil + // HACK (again, we're not told the right EventLoop): Let's write 0 bytes to make the user tell us... + return self.write(.byteBuffer(ByteBuffer())).flatMapWithEventLoop { (_, loop) in + func makeIteratorAndFirstChunk( + bytes: Bytes + ) -> (iterator: Iterator, chunk: Chunk)? { + var iterator = bytes.chunks(ofCount: maxChunkSize).enumerated().makeIterator() + guard let chunk = iterator.next() else { + return nil + } + + return (iterator, chunk) } - return (NIOLockedValueBox(iterator), chunk) - } - - guard let (iterator, chunk) = makeIteratorAndFirstChunk(bytes: bytes) else { - return self.write(IOData.byteBuffer(.init())) - } + guard let iteratorAndChunk = makeIteratorAndFirstChunk(bytes: bytes) else { + return loop.makeSucceededVoidFuture() + } - @Sendable // can't use closure here as we recursively call ourselves which closures can't do - func writeNextChunk(_ chunk: Chunk, allDone: EventLoopPromise) { - if let nextElement = iterator.withLockedValue({ $0.next() }) { - self.write(.byteBuffer(ByteBuffer(bytes: chunk.element))).map { - let index = nextElement.offset - if (index + 1) % 4 == 0 { - // Let's not stack-overflow if the futures insta-complete which they at least in HTTP/2 - // mode. - // Also, we must frequently return to the EventLoop because we may get the pause signal - // from another thread. If we fail to do that promptly, we may balloon our body chunks - // into memory. - allDone.futureResult.eventLoop.execute { - writeNextChunk(nextElement, allDone: allDone) + var iterator = iteratorAndChunk.0 + let chunk = iteratorAndChunk.1 + + // can't use closure here as we recursively call ourselves which closures can't do + func writeNextChunk(_ chunk: Chunk, allDone: EventLoopPromise) { + let loop = allDone.futureResult.eventLoop + loop.assertInEventLoop() + + if let (index, element) = iterator.next() { + self.write(.byteBuffer(ByteBuffer(bytes: chunk.element))).hop(to: loop).assumeIsolated().map + { + if (index + 1) % 4 == 0 { + // Let's not stack-overflow if the futures insta-complete which they at least in HTTP/2 + // mode. + // Also, we must frequently return to the EventLoop because we may get the pause signal + // from another thread. If we fail to do that promptly, we may balloon our body chunks + // into memory. + allDone.futureResult.eventLoop.assumeIsolated().execute { + writeNextChunk((offset: index, element: element), allDone: allDone) + } + } else { + writeNextChunk((offset: index, element: element), allDone: allDone) } - } else { - writeNextChunk(nextElement, allDone: allDone) - } - }.cascadeFailure(to: allDone) - } else { - self.write(.byteBuffer(ByteBuffer(bytes: chunk.element))).cascade(to: allDone) + }.nonisolated().cascadeFailure(to: allDone) + } else { + self.write(.byteBuffer(ByteBuffer(bytes: chunk.element))).cascade(to: allDone) + } } - } - // HACK (again, we're not told the right EventLoop): Let's write 0 bytes to make the user tell us... - return self.write(.byteBuffer(ByteBuffer())).flatMapWithEventLoop { (_, loop) in let allDone = loop.makePromise(of: Void.self) writeNextChunk(chunk, allDone: allDone) return allDone.futureResult @@ -209,7 +218,7 @@ extension HTTPClient { } /// Represents an HTTP request. - public struct Request { + public struct Request: Sendable { /// Request HTTP method, defaults to `GET`. public let method: HTTPMethod /// Remote URL. @@ -377,6 +386,13 @@ extension HTTPClient { public var headers: HTTPHeaders /// Response body. public var body: ByteBuffer? + /// The history of all requests and responses in redirect order. + public var history: [RequestResponse] + + /// The target URL (after redirects) of the response. + public var url: URL? { + self.history.last?.request.url + } /// Create HTTP `Response`. /// @@ -392,6 +408,7 @@ extension HTTPClient { self.version = HTTPVersion(major: 1, minor: 1) self.headers = headers self.body = body + self.history = [] } /// Create HTTP `Response`. @@ -414,6 +431,32 @@ extension HTTPClient { self.version = version self.headers = headers self.body = body + self.history = [] + } + + /// Create HTTP `Response`. + /// + /// - parameters: + /// - host: Remote host of the request. + /// - status: Response HTTP status. + /// - version: Response HTTP version. + /// - headers: Reponse HTTP headers. + /// - body: Response body. + /// - history: History of all requests and responses in redirect order. + public init( + host: String, + status: HTTPResponseStatus, + version: HTTPVersion, + headers: HTTPHeaders, + body: ByteBuffer?, + history: [RequestResponse] + ) { + self.host = host + self.status = status + self.version = version + self.headers = headers + self.body = body + self.history = history } } @@ -457,6 +500,16 @@ extension HTTPClient { } } } + + public struct RequestResponse: Sendable { + public var request: Request + public var responseHead: HTTPResponseHead + + public init(request: Request, responseHead: HTTPResponseHead) { + self.request = request + self.responseHead = responseHead + } + } } /// The default ``HTTPClientResponseDelegate``. @@ -485,7 +538,12 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate { } } - var state = State.idle + private struct MutableState: Sendable { + var history = [HTTPClient.RequestResponse]() + var state = State.idle + } + + private let state: NIOLockedValueBox let requestMethod: HTTPMethod let requestHost: String @@ -519,97 +577,126 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate { self.requestMethod = request.method self.requestHost = request.host self.maxBodySize = maxBodySize + self.state = NIOLockedValueBox(MutableState()) + } + + public func didVisitURL( + task: HTTPClient.Task, + _ request: HTTPClient.Request, + _ head: HTTPResponseHead + ) { + self.state.withLockedValue { + $0.history.append(.init(request: request, responseHead: head)) + } } public func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - switch self.state { - case .idle: - if self.requestMethod != .HEAD, - let contentLength = head.headers.first(name: "Content-Length"), - let announcedBodySize = Int(contentLength), - announcedBodySize > self.maxBodySize - { - let error = ResponseTooBigError(maxBodySize: maxBodySize) - self.state = .error(error) - return task.eventLoop.makeFailedFuture(error) - } + let responseTooBig: Bool + + if self.requestMethod != .HEAD, + let contentLength = head.headers.first(name: "Content-Length"), + let announcedBodySize = Int(contentLength), + announcedBodySize > self.maxBodySize + { + responseTooBig = true + } else { + responseTooBig = false + } - self.state = .head(head) - case .head: - preconditionFailure("head already set") - case .body: - preconditionFailure("no head received before body") - case .end: - preconditionFailure("request already processed") - case .error: - break - } - return task.eventLoop.makeSucceededFuture(()) + return self.state.withLockedValue { + switch $0.state { + case .idle: + if responseTooBig { + let error = ResponseTooBigError(maxBodySize: self.maxBodySize) + $0.state = .error(error) + return task.eventLoop.makeFailedFuture(error) + } + + $0.state = .head(head) + case .head: + preconditionFailure("head already set") + case .body: + preconditionFailure("no head received before body") + case .end: + preconditionFailure("request already processed") + case .error: + break + } + return task.eventLoop.makeSucceededFuture(()) + } } public func didReceiveBodyPart(task: HTTPClient.Task, _ part: ByteBuffer) -> EventLoopFuture { - switch self.state { - case .idle: - preconditionFailure("no head received before body") - case .head(let head): - guard part.readableBytes <= self.maxBodySize else { - let error = ResponseTooBigError(maxBodySize: self.maxBodySize) - self.state = .error(error) - return task.eventLoop.makeFailedFuture(error) - } - self.state = .body(head, part) - case .body(let head, var body): - let newBufferSize = body.writerIndex + part.readableBytes - guard newBufferSize <= self.maxBodySize else { - let error = ResponseTooBigError(maxBodySize: self.maxBodySize) - self.state = .error(error) - return task.eventLoop.makeFailedFuture(error) - } + self.state.withLockedValue { + switch $0.state { + case .idle: + preconditionFailure("no head received before body") + case .head(let head): + guard part.readableBytes <= self.maxBodySize else { + let error = ResponseTooBigError(maxBodySize: self.maxBodySize) + $0.state = .error(error) + return task.eventLoop.makeFailedFuture(error) + } + $0.state = .body(head, part) + case .body(let head, var body): + let newBufferSize = body.writerIndex + part.readableBytes + guard newBufferSize <= self.maxBodySize else { + let error = ResponseTooBigError(maxBodySize: self.maxBodySize) + $0.state = .error(error) + return task.eventLoop.makeFailedFuture(error) + } - // The compiler can't prove that `self.state` is dead here (and it kinda isn't, there's - // a cross-module call in the way) so we need to drop the original reference to `body` in - // `self.state` or we'll get a CoW. To fix that we temporarily set the state to `.end` (which - // has no associated data). We'll fix it at the bottom of this block. - self.state = .end - var part = part - body.writeBuffer(&part) - self.state = .body(head, body) - case .end: - preconditionFailure("request already processed") - case .error: - break - } - return task.eventLoop.makeSucceededFuture(()) + // The compiler can't prove that `self.state` is dead here (and it kinda isn't, there's + // a cross-module call in the way) so we need to drop the original reference to `body` in + // `self.state` or we'll get a CoW. To fix that we temporarily set the state to `.end` (which + // has no associated data). We'll fix it at the bottom of this block. + $0.state = .end + var part = part + body.writeBuffer(&part) + $0.state = .body(head, body) + case .end: + preconditionFailure("request already processed") + case .error: + break + } + return task.eventLoop.makeSucceededFuture(()) + } } public func didReceiveError(task: HTTPClient.Task, _ error: Error) { - self.state = .error(error) + self.state.withLockedValue { + $0.state = .error(error) + } } public func didFinishRequest(task: HTTPClient.Task) throws -> Response { - switch self.state { - case .idle: - preconditionFailure("no head received before end") - case .head(let head): - return Response( - host: self.requestHost, - status: head.status, - version: head.version, - headers: head.headers, - body: nil - ) - case .body(let head, let body): - return Response( - host: self.requestHost, - status: head.status, - version: head.version, - headers: head.headers, - body: body - ) - case .end: - preconditionFailure("request already processed") - case .error(let error): - throw error + try self.state.withLockedValue { + switch $0.state { + case .idle: + preconditionFailure("no head received before end") + case .head(let head): + return Response( + host: self.requestHost, + status: head.status, + version: head.version, + headers: head.headers, + body: nil, + history: $0.history + ) + case .body(let head, let body): + return Response( + host: self.requestHost, + status: head.status, + version: head.version, + headers: head.headers, + body: body, + history: $0.history + ) + case .end: + preconditionFailure("request already processed") + case .error(let error): + throw error + } } } } @@ -645,8 +732,9 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate { /// released together with the `HTTPTaskHandler` when channel is closed. /// Users of the library are not required to keep a reference to the /// object that implements this protocol, but may do so if needed. -public protocol HTTPClientResponseDelegate: AnyObject { - associatedtype Response +@preconcurrency +public protocol HTTPClientResponseDelegate: AnyObject, Sendable { + associatedtype Response: Sendable /// Called when the request head is sent. Will be called once. /// @@ -668,7 +756,16 @@ public protocol HTTPClientResponseDelegate: AnyObject { /// - task: Current request context. func didSendRequest(task: HTTPClient.Task) - /// Called when response head is received. Will be called once. + /// Called each time a response head is received (including redirects), and always called before ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd``. + /// You can use this method to keep an entire history of the request/response chain. + /// + /// - parameters: + /// - task: Current request context. + /// - request: The request that was sent. + /// - head: Received response head. + func didVisitURL(task: HTTPClient.Task, _ request: HTTPClient.Request, _ head: HTTPResponseHead) + + /// Called when the final response head is received (after redirects). /// You must return an `EventLoopFuture` that you complete when you have finished processing the body part. /// You can create an already succeeded future by calling `task.eventLoop.makeSucceededFuture(())`. /// @@ -734,6 +831,11 @@ extension HTTPClientResponseDelegate { /// By default, this does nothing. public func didSendRequest(task: HTTPClient.Task) {} + /// Default implementation of ``HTTPClientResponseDelegate/didVisitURL(task:_:_:)-2el9y``. + /// + /// By default, this does nothing. + public func didVisitURL(task: HTTPClient.Task, _: HTTPClient.Request, _: HTTPResponseHead) {} + /// Default implementation of ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd``. /// /// By default, this does nothing. @@ -807,7 +909,7 @@ extension URL { } } -protocol HTTPClientTaskDelegate { +protocol HTTPClientTaskDelegate: Sendable { func fail(_ error: Error) } @@ -816,7 +918,7 @@ extension HTTPClient { /// /// Will be created by the library and could be used for obtaining /// `EventLoopFuture` of the execution or cancellation of the execution. - public final class Task { + public final class Task: Sendable { /// The `EventLoop` the delegate will be executed on. public let eventLoop: EventLoop /// The `Logger` used by the `Task` for logging. @@ -824,41 +926,46 @@ extension HTTPClient { let promise: EventLoopPromise + struct State: Sendable { + var isCancelled: Bool + var taskDelegate: HTTPClientTaskDelegate? + } + + private let state: NIOLockedValueBox + var isCancelled: Bool { - self.lock.withLock { self._isCancelled } + self.state.withLockedValue { $0.isCancelled } } var taskDelegate: HTTPClientTaskDelegate? { get { - self.lock.withLock { self._taskDelegate } + self.state.withLockedValue { $0.taskDelegate } } set { - self.lock.withLock { self._taskDelegate = newValue } + self.state.withLockedValue { $0.taskDelegate = newValue } } } - private var _isCancelled: Bool = false - private var _taskDelegate: HTTPClientTaskDelegate? - private let lock = NIOLock() - private let makeOrGetFileIOThreadPool: () -> NIOThreadPool + private let makeOrGetFileIOThreadPool: @Sendable () -> NIOThreadPool /// The shared thread pool of a ``HTTPClient`` used for file IO. It is lazily created on first access. internal var fileIOThreadPool: NIOThreadPool { self.makeOrGetFileIOThreadPool() } - init(eventLoop: EventLoop, logger: Logger, makeOrGetFileIOThreadPool: @escaping () -> NIOThreadPool) { + init(eventLoop: EventLoop, logger: Logger, makeOrGetFileIOThreadPool: @escaping @Sendable () -> NIOThreadPool) { self.eventLoop = eventLoop self.promise = eventLoop.makePromise() self.logger = logger self.makeOrGetFileIOThreadPool = makeOrGetFileIOThreadPool + self.state = NIOLockedValueBox(State(isCancelled: false, taskDelegate: nil)) } static func failedTask( eventLoop: EventLoop, error: Error, logger: Logger, - makeOrGetFileIOThreadPool: @escaping () -> NIOThreadPool + makeOrGetFileIOThreadPool: @escaping @Sendable () -> NIOThreadPool ) -> Task { let task = self.init( eventLoop: eventLoop, @@ -879,60 +986,58 @@ extension HTTPClient { /// - returns: The value of ``futureResult`` when it completes. /// - throws: The error value of ``futureResult`` if it errors. @available(*, noasync, message: "wait() can block indefinitely, prefer get()", renamed: "get()") - public func wait() throws -> Response { + @preconcurrency + public func wait() throws -> Response where Response: Sendable { try self.promise.futureResult.wait() } /// Provides the result of this request. /// + /// - warning: This method may violates Structured Concurrency because doesn't respect cancellation. + /// /// - returns: The value of ``futureResult`` when it completes. /// - throws: The error value of ``futureResult`` if it errors. @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) - public func get() async throws -> Response { + @preconcurrency + public func get() async throws -> Response where Response: Sendable { try await self.promise.futureResult.get() } - /// Cancels the request execution. + /// Initiate cancellation of a HTTP request. + /// + /// This method will return immeidately and doesn't wait for the cancellation to complete. public func cancel() { self.fail(reason: HTTPClientError.cancelled) } - /// Cancels the request execution with a custom `Error`. + /// Initiate cancellation of a HTTP request with an `error`. + /// + /// This method will return immeidately and doesn't wait for the cancellation to complete. + /// /// - Parameter error: the error that is used to fail the promise public func fail(reason error: Error) { - let taskDelegate = self.lock.withLock { () -> HTTPClientTaskDelegate? in - self._isCancelled = true - return self._taskDelegate + let taskDelegate = self.state.withLockedValue { state in + state.isCancelled = true + return state.taskDelegate } taskDelegate?.fail(error) } - func succeed( - promise: EventLoopPromise?, - with value: Response, - delegateType: Delegate.Type, - closing: Bool - ) { - promise?.succeed(value) - } - - func fail( - with error: Error, - delegateType: Delegate.Type + /// Called internally only, used to fail a task from within the state machine functionality. + func failInternal( + with error: Error ) { self.promise.fail(error) } } } -extension HTTPClient.Task: @unchecked Sendable {} - internal struct TaskCancelEvent {} // MARK: - RedirectHandler -internal struct RedirectHandler { +internal struct RedirectHandler { let request: HTTPClient.Request let redirectState: RedirectState let execute: (HTTPClient.Request, RedirectState) -> HTTPClient.Task @@ -949,7 +1054,7 @@ internal struct RedirectHandler { status: HTTPResponseStatus, to redirectURL: URL, promise: EventLoopPromise - ) { + ) -> HTTPClient.Task? { do { var redirectState = self.redirectState try redirectState.redirect(to: redirectURL.absoluteString) @@ -969,13 +1074,19 @@ internal struct RedirectHandler { headers: headers, body: body ) - self.execute(newRequest, redirectState).futureResult.whenComplete { result in + + let newTask = self.execute(newRequest, redirectState) + + newTask.futureResult.whenComplete { result in promise.futureResult.eventLoop.execute { promise.completeWith(result) } } + + return newTask } catch { promise.fail(error) + return nil } } } diff --git a/Sources/AsyncHTTPClient/NIOLoopBound+Execute.swift b/Sources/AsyncHTTPClient/NIOLoopBound+Execute.swift new file mode 100644 index 000000000..b25a0f00d --- /dev/null +++ b/Sources/AsyncHTTPClient/NIOLoopBound+Execute.swift @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2025 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore + +extension NIOLoopBound { + @inlinable + func execute(_ body: @Sendable @escaping (Value) -> Void) { + if self.eventLoop.inEventLoop { + body(self.value) + } else { + self.eventLoop.execute { + body(self.value) + } + } + } +} diff --git a/Sources/AsyncHTTPClient/NIOTransportServices/TLSConfiguration.swift b/Sources/AsyncHTTPClient/NIOTransportServices/TLSConfiguration.swift index ef505e3b7..e8278e095 100644 --- a/Sources/AsyncHTTPClient/NIOTransportServices/TLSConfiguration.swift +++ b/Sources/AsyncHTTPClient/NIOTransportServices/TLSConfiguration.swift @@ -60,7 +60,7 @@ extension TLSVersion { @available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 5.0, *) extension TLSConfiguration { /// Dispatch queue used by Network framework TLS to control certificate verification - static var tlsDispatchQueue = DispatchQueue(label: "TLSDispatch") + static let tlsDispatchQueue = DispatchQueue(label: "TLSDispatch") /// create NWProtocolTLS.Options for use with NIOTransportServices from the NIOSSL TLSConfiguration /// diff --git a/Sources/AsyncHTTPClient/RequestBag.swift b/Sources/AsyncHTTPClient/RequestBag.swift index f2720d9ef..f206325ee 100644 --- a/Sources/AsyncHTTPClient/RequestBag.swift +++ b/Sources/AsyncHTTPClient/RequestBag.swift @@ -18,7 +18,8 @@ import NIOCore import NIOHTTP1 import NIOSSL -final class RequestBag { +@preconcurrency +final class RequestBag: Sendable { /// Defends against the call stack getting too large when consuming body parts. /// /// If the response body comes in lots of tiny chunks, we'll deliver those tiny chunks to users @@ -35,13 +36,23 @@ final class RequestBag { } private let delegate: Delegate - private var request: HTTPClient.Request - // the request state is synchronized on the task eventLoop - private var state: StateMachine + struct LoopBoundState: @unchecked Sendable { + // The 'StateMachine' *isn't* Sendable (it holds various objects which aren't). This type + // needs to be sendable so that we can construct a loop bound box off of the event loop + // to hold this state and then subsequently only access it from the event loop. This needs + // to happen so that the request bag can be constructed off of the event loop. If it's + // constructed on the event loop then there's a timing window between users issuing + // a request and calling shutdown where the underlying pool doesn't know about the request + // so the shutdown call may cancel it. + var request: HTTPClient.Request + var state: StateMachine + var consumeBodyPartStackDepth: Int + // if a redirect occurs, we store the task for it so we can propagate cancellation + var redirectTask: HTTPClient.Task? = nil + } - // the consume body part stack depth is synchronized on the task event loop. - private var consumeBodyPartStackDepth: Int + private let loopBoundState: NIOLoopBoundBox // MARK: HTTPClientTask properties @@ -58,6 +69,8 @@ final class RequestBag { let eventLoopPreference: HTTPClient.EventLoopPreference + let tlsConfiguration: TLSConfiguration? + init( request: HTTPClient.Request, eventLoopPreference: HTTPClient.EventLoopPreference, @@ -70,9 +83,13 @@ final class RequestBag { self.poolKey = .init(request, dnsOverride: requestOptions.dnsOverride) self.eventLoopPreference = eventLoopPreference self.task = task - self.state = .init(redirectHandler: redirectHandler) - self.consumeBodyPartStackDepth = 0 - self.request = request + + let loopBoundState = LoopBoundState( + request: request, + state: StateMachine(redirectHandler: redirectHandler), + consumeBodyPartStackDepth: 0 + ) + self.loopBoundState = NIOLoopBoundBox.makeBoxSendingValue(loopBoundState, eventLoop: task.eventLoop) self.connectionDeadline = connectionDeadline self.requestOptions = requestOptions self.delegate = delegate @@ -81,6 +98,8 @@ final class RequestBag { self.requestHead = head self.requestFramingMetadata = metadata + self.tlsConfiguration = request.tlsConfiguration + self.task.taskDelegate = self self.task.futureResult.whenComplete { _ in self.task.taskDelegate = nil @@ -89,22 +108,19 @@ final class RequestBag { private func requestWasQueued0(_ scheduler: HTTPRequestScheduler) { self.logger.debug("Request was queued (waiting for a connection to become available)") - - self.task.eventLoop.assertInEventLoop() - self.state.requestWasQueued(scheduler) + self.loopBoundState.value.state.requestWasQueued(scheduler) } // MARK: - Request - private func willExecuteRequest0(_ executor: HTTPRequestExecutor) { - self.task.eventLoop.assertInEventLoop() - let action = self.state.willExecuteRequest(executor) + let action = self.loopBoundState.value.state.willExecuteRequest(executor) switch action { case .cancelExecuter(let executor): executor.cancelRequest(self) case .failTaskAndCancelExecutor(let error, let executor): self.delegate.didReceiveError(task: self.task, error) - self.task.fail(with: error, delegateType: Delegate.self) + self.task.failInternal(with: error) executor.cancelRequest(self) case .none: break @@ -112,26 +128,22 @@ final class RequestBag { } private func requestHeadSent0() { - self.task.eventLoop.assertInEventLoop() - self.delegate.didSendRequestHead(task: self.task, self.requestHead) - if self.request.body == nil { + if self.loopBoundState.value.request.body == nil { self.delegate.didSendRequest(task: self.task) } } private func resumeRequestBodyStream0() { - self.task.eventLoop.assertInEventLoop() - - let produceAction = self.state.resumeRequestBodyStream() + let produceAction = self.loopBoundState.value.state.resumeRequestBodyStream() switch produceAction { case .startWriter: - guard let body = self.request.body else { + guard let body = self.loopBoundState.value.request.body else { preconditionFailure("Expected to have a body, if the `HTTPRequestStateMachine` resume a request stream") } - self.request.body = nil + self.loopBoundState.value.request.body = nil let writer = HTTPClient.Body.StreamWriter { self.writeNextRequestPart($0) @@ -150,9 +162,7 @@ final class RequestBag { } private func pauseRequestBodyStream0() { - self.task.eventLoop.assertInEventLoop() - - self.state.pauseRequestBodyStream() + self.loopBoundState.value.state.pauseRequestBodyStream() } private func writeNextRequestPart(_ part: IOData) -> EventLoopFuture { @@ -166,14 +176,12 @@ final class RequestBag { } private func writeNextRequestPart0(_ part: IOData) -> EventLoopFuture { - self.eventLoop.assertInEventLoop() - - let action = self.state.writeNextRequestPart(part, taskEventLoop: self.task.eventLoop) + let action = self.loopBoundState.value.state.writeNextRequestPart(part, taskEventLoop: self.task.eventLoop) switch action { case .failTask(let error): self.delegate.didReceiveError(task: self.task, error) - self.task.fail(with: error, delegateType: Delegate.self) + self.task.failInternal(with: error) return self.task.eventLoop.makeFailedFuture(error) case .failFuture(let error): @@ -190,9 +198,7 @@ final class RequestBag { } private func finishRequestBodyStream(_ result: Result) { - self.task.eventLoop.assertInEventLoop() - - let action = self.state.finishRequestBodyStream(result) + let action = self.loopBoundState.value.state.finishRequestBodyStream(result) switch action { case .none: @@ -223,10 +229,10 @@ final class RequestBag { // MARK: - Response - private func receiveResponseHead0(_ head: HTTPResponseHead) { - self.task.eventLoop.assertInEventLoop() + self.delegate.didVisitURL(task: self.task, self.loopBoundState.value.request, head) // runs most likely on channel eventLoop - switch self.state.receiveResponseHead(head) { + switch self.loopBoundState.value.state.receiveResponseHead(head) { case .none: break @@ -234,7 +240,11 @@ final class RequestBag { executor.demandResponseBodyStream(self) case .redirect(let executor, let handler, let head, let newURL): - handler.redirect(status: head.status, to: newURL, promise: self.task.promise) + self.loopBoundState.value.redirectTask = handler.redirect( + status: head.status, + to: newURL, + promise: self.task.promise + ) executor.cancelRequest(self) case .forwardResponseHead(let head): @@ -248,9 +258,7 @@ final class RequestBag { } private func receiveResponseBodyParts0(_ buffer: CircularBuffer) { - self.task.eventLoop.assertInEventLoop() - - switch self.state.receiveResponseBodyParts(buffer) { + switch self.loopBoundState.value.state.receiveResponseBodyParts(buffer) { case .none: break @@ -258,7 +266,11 @@ final class RequestBag { executor.demandResponseBodyStream(self) case .redirect(let executor, let handler, let head, let newURL): - handler.redirect(status: head.status, to: newURL, promise: self.task.promise) + self.loopBoundState.value.redirectTask = handler.redirect( + status: head.status, + to: newURL, + promise: self.task.promise + ) executor.cancelRequest(self) case .forwardResponsePart(let part): @@ -272,8 +284,7 @@ final class RequestBag { } private func succeedRequest0(_ buffer: CircularBuffer?) { - self.task.eventLoop.assertInEventLoop() - let action = self.state.succeedRequest(buffer) + let action = self.loopBoundState.value.state.succeedRequest(buffer) switch action { case .none: @@ -294,13 +305,15 @@ final class RequestBag { } case .redirect(let handler, let head, let newURL): - handler.redirect(status: head.status, to: newURL, promise: self.task.promise) + self.loopBoundState.value.redirectTask = handler.redirect( + status: head.status, + to: newURL, + promise: self.task.promise + ) } } private func consumeMoreBodyData0(resultOfPreviousConsume result: Result) { - self.task.eventLoop.assertInEventLoop() - // We get defensive here about the maximum stack depth. It's possible for the `didReceiveBodyPart` // future to be returned to us completed. If it is, we will recurse back into this method. To // break that recursion we have a max stack depth which we increment and decrement in this method: @@ -311,24 +324,27 @@ final class RequestBag { // that risk ending up in this loop. That's because we don't need an accurate count: our limit is // a best-effort target anyway, one stack frame here or there does not put us at risk. We're just // trying to prevent ourselves looping out of control. - self.consumeBodyPartStackDepth += 1 + self.loopBoundState.value.consumeBodyPartStackDepth += 1 defer { - self.consumeBodyPartStackDepth -= 1 - assert(self.consumeBodyPartStackDepth >= 0) + self.loopBoundState.value.consumeBodyPartStackDepth -= 1 + assert(self.loopBoundState.value.consumeBodyPartStackDepth >= 0) } - let consumptionAction = self.state.consumeMoreBodyData(resultOfPreviousConsume: result) + let consumptionAction = self.loopBoundState.value.state.consumeMoreBodyData( + resultOfPreviousConsume: result + ) switch consumptionAction { case .consume(let byteBuffer): self.delegate.didReceiveBodyPart(task: self.task, byteBuffer) .hop(to: self.task.eventLoop) + .assumeIsolated() .whenComplete { result in - if self.consumeBodyPartStackDepth < Self.maxConsumeBodyPartStackDepth { + if self.loopBoundState.value.consumeBodyPartStackDepth < Self.maxConsumeBodyPartStackDepth { self.consumeMoreBodyData0(resultOfPreviousConsume: result) } else { // We need to unwind the stack, let's take a break. - self.task.eventLoop.execute { + self.task.eventLoop.assumeIsolated().execute { self.consumeMoreBodyData0(resultOfPreviousConsume: result) } } @@ -339,7 +355,7 @@ final class RequestBag { case .finishStream: do { let response = try self.delegate.didFinishRequest(task: self.task) - self.task.promise.succeed(response) + self.task.promise.assumeIsolated().succeed(response) } catch { self.task.promise.fail(error) } @@ -353,11 +369,11 @@ final class RequestBag { } private func fail0(_ error: Error) { - self.task.eventLoop.assertInEventLoop() - - let action = self.state.fail(error) + let action = self.loopBoundState.value.state.fail(error) self.executeFailAction0(action) + + self.loopBoundState.value.redirectTask?.fail(reason: error) } private func executeFailAction0(_ action: RequestBag.StateMachine.FailAction) { @@ -374,8 +390,7 @@ final class RequestBag { } func deadlineExceeded0() { - self.task.eventLoop.assertInEventLoop() - let action = self.state.deadlineExceeded() + let action = self.loopBoundState.value.state.deadlineExceeded() switch action { case .cancelScheduler(let scheduler): @@ -397,9 +412,6 @@ final class RequestBag { } extension RequestBag: HTTPSchedulableRequest, HTTPClientTaskDelegate { - var tlsConfiguration: TLSConfiguration? { - self.request.tlsConfiguration - } func requestWasQueued(_ scheduler: HTTPRequestScheduler) { if self.task.eventLoop.inEventLoop { diff --git a/Sources/AsyncHTTPClient/StructuredConcurrencyHelpers.swift b/Sources/AsyncHTTPClient/StructuredConcurrencyHelpers.swift new file mode 100644 index 000000000..25f1225e0 --- /dev/null +++ b/Sources/AsyncHTTPClient/StructuredConcurrencyHelpers.swift @@ -0,0 +1,83 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2025 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// swift-format-ignore +// Note: Whitespace changes are used to workaround compiler bug +// https://github.com/swiftlang/swift/issues/79285 + +#if compiler(>=6.0) +@inlinable +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +internal func asyncDo( + isolation: isolated (any Actor)? = #isolation, + // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED + // https://github.com/swiftlang/swift/issues/79285 + _ body: () async throws -> sending R, finally: sending @escaping ((any Error)?) async throws -> Void) async throws -> sending R { + let result: R + do { + result = try await body() + } catch { + // `body` failed, we need to invoke `finally` with the `error`. + + // This _looks_ unstructured but isn't really because we unconditionally always await the return. + // We need to have an uncancelled task here to assure this is actually running in case we hit a + // cancellation error. + try await Task { + try await finally(error) + }.value + throw error + } + + // `body` succeeded, we need to invoke `finally` with `nil` (no error). + + // This _looks_ unstructured but isn't really because we unconditionally always await the return. + // We need to have an uncancelled task here to assure this is actually running in case we hit a + // cancellation error. + try await Task { + try await finally(nil) + }.value + return result +} +#else +@inlinable +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +internal func asyncDo( + _ body: () async throws -> R, + finally: @escaping @Sendable ((any Error)?) async throws -> Void +) async throws -> R { + let result: R + do { + result = try await body() + } catch { + // `body` failed, we need to invoke `finally` with the `error`. + + // This _looks_ unstructured but isn't really because we unconditionally always await the return. + // We need to have an uncancelled task here to assure this is actually running in case we hit a + // cancellation error. + try await Task { + try await finally(error) + }.value + throw error + } + + // `body` succeeded, we need to invoke `finally` with `nil` (no error). + + // This _looks_ unstructured but isn't really because we unconditionally always await the return. + // We need to have an uncancelled task here to assure this is actually running in case we hit a + // cancellation error. + try await Task { + try await finally(nil) + }.value + return result +} +#endif diff --git a/Sources/AsyncHTTPClient/Utils.swift b/Sources/AsyncHTTPClient/Utils.swift index abdd5bbc2..985755143 100644 --- a/Sources/AsyncHTTPClient/Utils.swift +++ b/Sources/AsyncHTTPClient/Utils.swift @@ -18,10 +18,10 @@ import NIOCore /// /// ``HTTPClientCopyingDelegate`` discards most parts of a HTTP response, but streams the body /// to the `chunkHandler` provided on ``init(chunkHandler:)``. This is mostly useful for testing. -public final class HTTPClientCopyingDelegate: HTTPClientResponseDelegate { +public final class HTTPClientCopyingDelegate: HTTPClientResponseDelegate, Sendable { public typealias Response = Void - let chunkHandler: (ByteBuffer) -> EventLoopFuture + let chunkHandler: @Sendable (ByteBuffer) -> EventLoopFuture @preconcurrency public init(chunkHandler: @Sendable @escaping (ByteBuffer) -> EventLoopFuture) { diff --git a/Sources/CAsyncHTTPClient/CAsyncHTTPClient.c b/Sources/CAsyncHTTPClient/CAsyncHTTPClient.c index 5dfdc08a5..6342da89f 100644 --- a/Sources/CAsyncHTTPClient/CAsyncHTTPClient.c +++ b/Sources/CAsyncHTTPClient/CAsyncHTTPClient.c @@ -31,7 +31,7 @@ bool swiftahc_cshims_strptime(const char * string, const char * format, struct t bool swiftahc_cshims_strptime_l(const char * string, const char * format, struct tm * result, void * locale) { // The pointer cast is fine as long we make sure it really points to a locale_t. -#ifdef __musl__ +#if defined(__musl__) || defined(__ANDROID__) const char * firstNonProcessed = strptime(string, format, result); #else const char * firstNonProcessed = strptime_l(string, format, result, (locale_t)locale); diff --git a/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift b/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift index 4bfa86d14..56a08b852 100644 --- a/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift +++ b/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift @@ -76,6 +76,8 @@ final class AsyncAwaitEndToEndTests: XCTestCase { return } + XCTAssertEqual(response.url?.absoluteString, request.url) + XCTAssertEqual(response.history.map(\.request.url), [request.url]) XCTAssertEqual(response.status, .ok) XCTAssertEqual(response.version, .http2) } @@ -98,6 +100,8 @@ final class AsyncAwaitEndToEndTests: XCTestCase { return } + XCTAssertEqual(response.url?.absoluteString, request.url) + XCTAssertEqual(response.history.map(\.request.url), [request.url]) XCTAssertEqual(response.status, .ok) XCTAssertEqual(response.version, .http2) } @@ -522,6 +526,8 @@ final class AsyncAwaitEndToEndTests: XCTestCase { } func testConnectTimeout() { + let serverGroup = self.serverGroup! + let clientGroup = self.clientGroup! XCTAsyncTest(timeout: 60) { #if os(Linux) // 198.51.100.254 is reserved for documentation only and therefore should not accept any TCP connection @@ -538,7 +544,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { XCTAssertNoThrow(try group.syncShutdownGracefully()) } - let serverChannel = try await ServerBootstrap(group: self.serverGroup) + let serverChannel = try await ServerBootstrap(group: serverGroup) .serverChannelOption(ChannelOptions.backlog, value: 1) .serverChannelOption(ChannelOptions.autoRead, value: false) .bind(host: "127.0.0.1", port: 0) @@ -547,7 +553,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { XCTAssertNoThrow(try serverChannel.close().wait()) } let port = serverChannel.localAddress!.port! - let firstClientChannel = try await ClientBootstrap(group: self.serverGroup) + let firstClientChannel = try await ClientBootstrap(group: serverGroup) .connect(host: "127.0.0.1", port: port) .get() defer { @@ -557,7 +563,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { #endif let httpClient = HTTPClient( - eventLoopGroupProvider: .shared(self.clientGroup), + eventLoopGroupProvider: .shared(clientGroup), configuration: .init(timeout: .init(connect: .milliseconds(100), read: .milliseconds(150))) ) @@ -595,7 +601,9 @@ final class AsyncAwaitEndToEndTests: XCTestCase { defer { XCTAssertNoThrow(try serverGroup.syncShutdownGracefully()) } let server = ServerBootstrap(group: serverGroup) .childChannelInitializer { channel in - channel.pipeline.addHandler(NIOSSLServerHandler(context: sslContext)) + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext)) + } } let serverChannel = try await server.bind(host: "localhost", port: 0).get() defer { XCTAssertNoThrow(try serverChannel.close().wait()) } @@ -629,33 +637,33 @@ final class AsyncAwaitEndToEndTests: XCTestCase { func testDnsOverride() { XCTAsyncTest(timeout: 5) { - /// key + cert was created with the following code (depends on swift-certificates) - /// ``` - /// import X509 - /// import CryptoKit - /// import Foundation - /// - /// let privateKey = P384.Signing.PrivateKey() - /// let name = try DistinguishedName { - /// OrganizationName("Self Signed") - /// CommonName("localhost") - /// } - /// let certificate = try Certificate( - /// version: .v3, - /// serialNumber: .init(), - /// publicKey: .init(privateKey.publicKey), - /// notValidBefore: Date(), - /// notValidAfter: Date().advanced(by: 365 * 24 * 3600), - /// issuer: name, - /// subject: name, - /// signatureAlgorithm: .ecdsaWithSHA384, - /// extensions: try .init { - /// SubjectAlternativeNames([.dnsName("example.com")]) - /// try ExtendedKeyUsage([.serverAuth]) - /// }, - /// issuerPrivateKey: .init(privateKey) - /// ) - /// ``` + // key + cert was created with the following code (depends on swift-certificates) + // ``` + // import X509 + // import CryptoKit + // import Foundation + // + // let privateKey = P384.Signing.PrivateKey() + // let name = try DistinguishedName { + // OrganizationName("Self Signed") + // CommonName("localhost") + // } + // let certificate = try Certificate( + // version: .v3, + // serialNumber: .init(), + // publicKey: .init(privateKey.publicKey), + // notValidBefore: Date(), + // notValidAfter: Date().advanced(by: 365 * 24 * 3600), + // issuer: name, + // subject: name, + // signatureAlgorithm: .ecdsaWithSHA384, + // extensions: try .init { + // SubjectAlternativeNames([.dnsName("example.com")]) + // try ExtendedKeyUsage([.serverAuth]) + // }, + // issuerPrivateKey: .init(privateKey) + // ) + // ``` let certPath = Bundle.module.path(forResource: "example.com.cert", ofType: "pem")! let keyPath = Bundle.module.path(forResource: "example.com.private-key", ofType: "pem")! let key = try NIOSSLPrivateKey(file: keyPath, format: .pem) @@ -732,9 +740,10 @@ final class AsyncAwaitEndToEndTests: XCTestCase { defer { XCTAssertNoThrow(try client.syncShutdown()) } let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) var request = HTTPClientRequest(url: "https://127.0.0.1/(bin.port)/redirect/target") + let redirectURL = "https://localhost/(bin.port)/echohostheader" request.headers.replaceOrAdd( name: "X-Target-Redirect-URL", - value: "https://localhost/(bin.port)/echohostheader" + value: redirectURL ) guard @@ -751,6 +760,8 @@ final class AsyncAwaitEndToEndTests: XCTestCase { XCTAssertNoThrow(maybeRequestInfo = try JSONDecoder().decode(RequestInfo.self, from: body)) guard let requestInfo = maybeRequestInfo else { return } + XCTAssertEqual(response.url?.absoluteString, redirectURL) + XCTAssertEqual(response.history.map(\.request.url), [request.url, redirectURL]) XCTAssertEqual(response.status, .ok) XCTAssertEqual(response.version, .http2) XCTAssertEqual(requestInfo.data, "localhost:\(bin.port)") diff --git a/Tests/AsyncHTTPClientTests/AsyncTestHelpers.swift b/Tests/AsyncHTTPClientTests/AsyncTestHelpers.swift index cbab922a4..4a5c8d486 100644 --- a/Tests/AsyncHTTPClientTests/AsyncTestHelpers.swift +++ b/Tests/AsyncHTTPClientTests/AsyncTestHelpers.swift @@ -17,7 +17,7 @@ import NIOCore /// ``AsyncSequenceWriter`` is `Sendable` because its state is protected by a Lock @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -final class AsyncSequenceWriter: AsyncSequence, @unchecked Sendable { +final class AsyncSequenceWriter: AsyncSequence, @unchecked Sendable { typealias AsyncIterator = Iterator struct Iterator: AsyncIteratorProtocol { diff --git a/Tests/AsyncHTTPClientTests/EmbeddedChannel+HTTPConvenience.swift b/Tests/AsyncHTTPClientTests/EmbeddedChannel+HTTPConvenience.swift index 914d03612..5cc35bce8 100644 --- a/Tests/AsyncHTTPClientTests/EmbeddedChannel+HTTPConvenience.swift +++ b/Tests/AsyncHTTPClientTests/EmbeddedChannel+HTTPConvenience.swift @@ -59,7 +59,7 @@ extension EmbeddedChannel { } struct HTTP1TestTools { - let connection: HTTP1Connection + let connection: HTTP1Connection.SendableView let connectionDelegate: MockConnectionDelegate let readEventHandler: ReadEventHitHandler let logger: Logger @@ -87,8 +87,8 @@ extension EmbeddedChannel { let decoder = try self.pipeline.syncOperations.handler(type: ByteToMessageHandler.self) let encoder = try self.pipeline.syncOperations.handler(type: HTTPRequestEncoder.self) - let removeDecoderFuture = self.pipeline.removeHandler(decoder) - let removeEncoderFuture = self.pipeline.removeHandler(encoder) + let removeDecoderFuture = self.pipeline.syncOperations.removeHandler(decoder) + let removeEncoderFuture = self.pipeline.syncOperations.removeHandler(encoder) self.embeddedEventLoop.run() @@ -96,7 +96,7 @@ extension EmbeddedChannel { try removeEncoderFuture.wait() return .init( - connection: connection, + connection: connection.sendableView, connectionDelegate: connectionDelegate, readEventHandler: readEventHandler, logger: logger diff --git a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift index 53af0823d..0d871b7dc 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// import Logging +import NIOConcurrencyHelpers import NIOCore import NIOEmbedded import NIOHTTP1 @@ -833,43 +834,108 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { ) try channel.connect(to: .init(ipAddress: "127.0.0.1", port: 80)).wait() - let request = MockHTTPExecutableRequest() // non empty body is important to trigger this bug as we otherwise finish the request in a single flush - request.requestFramingMetadata.body = .fixedSize(1) - request.raiseErrorIfUnimplementedMethodIsCalled = false + let request = MockHTTPExecutableRequest( + framingMetadata: RequestFramingMetadata(connectionClose: false, body: .fixedSize(1)), + raiseErrorIfUnimplementedMethodIsCalled: false + ) channel.writeAndFlush(request, promise: nil) XCTAssertEqual(request.events.map(\.kind), [.willExecuteRequest, .requestHeadSent]) } + + func testIdleWriteTimeoutOutsideOfRunningState() { + let embedded = EmbeddedChannel() + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + print("pipeline", embedded.pipeline) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/")) + guard var request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + // start a request stream we'll never write to + let streamPromise = embedded.eventLoop.makePromise(of: Void.self) + let streamCallback = { @Sendable (streamWriter: HTTPClient.Body.StreamWriter) -> EventLoopFuture in + streamPromise.futureResult + } + request.body = .init(contentLength: nil, stream: streamCallback) + + let accumulator = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests( + idleReadTimeout: .milliseconds(10), + idleWriteTimeout: .milliseconds(2) + ), + delegate: accumulator + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + testUtils.connection.executeRequest(requestBag) + + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) + + // close the pipeline to simulate a server-side close + // note this happens before we write so the idle write timeout is still running + try! embedded.pipeline.close().wait() + + // advance time to trigger the idle write timeout + // and ensure that the state machine can tolerate this + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) + } } -class TestBackpressureWriter { +final class TestBackpressureWriter: Sendable { let eventLoop: EventLoop let parts: Int var finishFuture: EventLoopFuture { self.finishPromise.futureResult } private let finishPromise: EventLoopPromise - private(set) var written: Int = 0 - private var channelIsWritable: Bool = false + private struct State { + var written = 0 + var channelIsWritable = false + } + + var written: Int { + self.state.value.written + } + + private let state: NIOLoopBoundBox init(eventLoop: EventLoop, parts: Int) { self.eventLoop = eventLoop self.parts = parts - + self.state = .makeBoxSendingValue(State(), eventLoop: eventLoop) self.finishPromise = eventLoop.makePromise(of: Void.self) } func start(writer: HTTPClient.Body.StreamWriter, expectedErrors: [HTTPClientError] = []) -> EventLoopFuture { + @Sendable func recursive() { XCTAssert(self.eventLoop.inEventLoop) - XCTAssert(self.channelIsWritable) - if self.written == self.parts { + XCTAssert(self.state.value.channelIsWritable) + if self.state.value.written == self.parts { self.finishPromise.succeed(()) } else { self.eventLoop.execute { let future = writer.write(.byteBuffer(.init(bytes: [0, 1]))) - self.written += 1 + self.state.value.written += 1 future.whenComplete { result in switch result { case .success: @@ -896,14 +962,14 @@ class TestBackpressureWriter { } func writabilityChanged(_ newValue: Bool) { - self.channelIsWritable = newValue + self.state.value.channelIsWritable = newValue } } -class ResponseBackpressureDelegate: HTTPClientResponseDelegate { +final class ResponseBackpressureDelegate: HTTPClientResponseDelegate { typealias Response = Void - enum State { + enum State: Sendable { case consuming(EventLoopPromise) case waitingForRemote(CircularBuffer>) case buffering((ByteBuffer?, EventLoopPromise)?) @@ -911,21 +977,20 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { } let eventLoop: EventLoop - private var state: State = .buffering(nil) + private let state: NIOLoopBoundBox init(eventLoop: EventLoop) { self.eventLoop = eventLoop - - self.state = .consuming(self.eventLoop.makePromise(of: Void.self)) + self.state = .makeBoxSendingValue(.consuming(eventLoop.makePromise(of: Void.self)), eventLoop: eventLoop) } func next() -> EventLoopFuture { - switch self.state { + switch self.state.value { case .consuming(let backpressurePromise): var promiseBuffer = CircularBuffer>() let newPromise = self.eventLoop.makePromise(of: ByteBuffer?.self) promiseBuffer.append(newPromise) - self.state = .waitingForRemote(promiseBuffer) + self.state.value = .waitingForRemote(promiseBuffer) backpressurePromise.succeed(()) return newPromise.futureResult @@ -936,18 +1001,18 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { ) let promise = self.eventLoop.makePromise(of: ByteBuffer?.self) promiseBuffer.append(promise) - self.state = .waitingForRemote(promiseBuffer) + self.state.value = .waitingForRemote(promiseBuffer) return promise.futureResult case .buffering(.none): var promiseBuffer = CircularBuffer>() let promise = self.eventLoop.makePromise(of: ByteBuffer?.self) promiseBuffer.append(promise) - self.state = .waitingForRemote(promiseBuffer) + self.state.value = .waitingForRemote(promiseBuffer) return promise.futureResult case .buffering(.some((let buffer, let promise))): - self.state = .buffering(nil) + self.state.value = .buffering(nil) promise.succeed(()) return self.eventLoop.makeSucceededFuture(buffer) @@ -957,7 +1022,7 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { } func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - switch self.state { + switch self.state.value { case .consuming(let backpressurePromise): return backpressurePromise.futureResult @@ -970,7 +1035,7 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { } func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { - switch self.state { + switch self.state.value { case .waitingForRemote(var promiseBuffer): assert( !promiseBuffer.isEmpty, @@ -979,18 +1044,18 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { let promise = promiseBuffer.removeFirst() if promiseBuffer.isEmpty { let newBackpressurePromise = self.eventLoop.makePromise(of: Void.self) - self.state = .consuming(newBackpressurePromise) + self.state.value = .consuming(newBackpressurePromise) promise.succeed(buffer) return newBackpressurePromise.futureResult } else { - self.state = .waitingForRemote(promiseBuffer) + self.state.value = .waitingForRemote(promiseBuffer) promise.succeed(buffer) return self.eventLoop.makeSucceededVoidFuture() } case .buffering(.none): let promise = self.eventLoop.makePromise(of: Void.self) - self.state = .buffering((buffer, promise)) + self.state.value = .buffering((buffer, promise)) return promise.futureResult case .buffering(.some): @@ -1004,15 +1069,15 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { } func didFinishRequest(task: HTTPClient.Task) throws { - switch self.state { + switch self.state.value { case .waitingForRemote(let promiseBuffer): for promise in promiseBuffer { promise.succeed(.none) } - self.state = .done + self.state.value = .done case .buffering(.none): - self.state = .done + self.state.value = .done case .done, .consuming: preconditionFailure("Invalid state: \(self.state)") @@ -1038,7 +1103,7 @@ class ReadEventHitHandler: ChannelOutboundHandler { } } -final class FailEndHandler: ChannelOutboundHandler { +final class FailEndHandler: ChannelOutboundHandler, Sendable { typealias OutboundIn = HTTPClientRequestPart typealias OutboundOut = HTTPClientRequestPart diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift index 18831d32f..1c6e9659f 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift @@ -101,6 +101,26 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.read(), .read) } + func testWriteTimeoutAfterErrorDoesntCrash() { + var state = HTTP1ConnectionStateMachine() + XCTAssertEqual(state.channelActive(isWritable: true), .fireChannelActive) + + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) + let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) + + struct MyError: Error, Equatable {} + XCTAssertEqual(state.errorHappened(MyError()), .failRequest(MyError(), .close(nil))) + + // Primarily we care that we don't crash here + XCTAssertEqual(state.idleWriteTimeoutTriggered(), .wait) + } + func testAConnectionCloseHeaderInTheRequestLeadsToConnectionCloseAfterRequest() { var state = HTTP1ConnectionStateMachine() XCTAssertEqual(state.channelActive(isWritable: true), .fireChannelActive) diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift index 5f980bccb..53001b64b 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift @@ -48,7 +48,7 @@ class HTTP1ConnectionTests: XCTestCase { ) XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: NIOHTTPResponseDecompressor.self)) - XCTAssertNoThrow(try connection?.close().wait()) + XCTAssertNoThrow(try connection?.sendableView.close().wait()) embedded.embeddedEventLoop.run() XCTAssert(!embedded.isActive) } @@ -108,8 +108,7 @@ class HTTP1ConnectionTests: XCTestCase { defer { XCTAssertNoThrow(try server.stop()) } let logger = Logger(label: "test") - let delegate = MockHTTP1ConnectionDelegate() - delegate.closePromise = clientEL.makePromise(of: Void.self) + let delegate = MockHTTP1ConnectionDelegate(closePromise: clientEL.makePromise()) let connection = try! ClientBootstrap(group: clientEL) .connect(to: .init(ipAddress: "127.0.0.1", port: server.serverPort)) @@ -120,7 +119,7 @@ class HTTP1ConnectionTests: XCTestCase { delegate: delegate, decompression: .disabled, logger: logger - ) + ).sendableView } .wait() @@ -223,16 +222,16 @@ class HTTP1ConnectionTests: XCTestCase { ) let connectionDelegate = MockConnectionDelegate() let logger = Logger(label: "test") - var maybeConnection: HTTP1Connection? + var maybeConnection: HTTP1Connection.SendableView? XCTAssertNoThrow( - maybeConnection = try eventLoop.submit { + maybeConnection = try eventLoop.submit { [maybeChannel] in try HTTP1Connection.start( channel: XCTUnwrap(maybeChannel), connectionID: 0, delegate: connectionDelegate, decompression: .disabled, logger: logger - ) + ).sendableView }.wait() ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection here") } @@ -287,16 +286,16 @@ class HTTP1ConnectionTests: XCTestCase { ) let connectionDelegate = MockConnectionDelegate() let logger = Logger(label: "test") - var maybeConnection: HTTP1Connection? + var maybeConnection: HTTP1Connection.SendableView? XCTAssertNoThrow( - maybeConnection = try eventLoop.submit { + maybeConnection = try eventLoop.submit { [maybeChannel] in try HTTP1Connection.start( channel: XCTUnwrap(maybeChannel), connectionID: 0, delegate: connectionDelegate, decompression: .disabled, logger: logger - ) + ).sendableView }.wait() ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection here") } @@ -372,16 +371,16 @@ class HTTP1ConnectionTests: XCTestCase { ) let connectionDelegate = MockConnectionDelegate() let logger = Logger(label: "test") - var maybeConnection: HTTP1Connection? + var maybeConnection: HTTP1Connection.SendableView? XCTAssertNoThrow( - maybeConnection = try eventLoop.submit { + maybeConnection = try eventLoop.submit { [maybeChannel] in try HTTP1Connection.start( channel: XCTUnwrap(maybeChannel), connectionID: 0, delegate: connectionDelegate, decompression: .disabled, logger: logger - ) + ).sendableView }.wait() ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection here") } @@ -454,7 +453,7 @@ class HTTP1ConnectionTests: XCTestCase { ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } - connection.executeRequest(requestBag) + connection.sendableView.executeRequest(requestBag) XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end @@ -523,7 +522,7 @@ class HTTP1ConnectionTests: XCTestCase { ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } - connection.executeRequest(requestBag) + connection.sendableView.executeRequest(requestBag) XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end @@ -626,7 +625,7 @@ class HTTP1ConnectionTests: XCTestCase { ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } - connection.executeRequest(requestBag) + connection.sendableView.executeRequest(requestBag) let responseString = """ HTTP/1.0 200 OK\r\n\ @@ -654,31 +653,31 @@ class HTTP1ConnectionTests: XCTestCase { // bytes a ready to be read as well. This will allow us to test if subsequent reads // are waiting for backpressure promise. func testDownloadStreamingBackpressure() { - class BackpressureTestDelegate: HTTPClientResponseDelegate { + final class BackpressureTestDelegate: HTTPClientResponseDelegate { typealias Response = Void - var _reads = 0 - var _channel: Channel? + private struct State: Sendable { + var reads = 0 + var channel: Channel? + } + + private let state = NIOLockedValueBox(State()) + + var reads: Int { + self.state.withLockedValue { $0.reads } + } - let lock: NIOLock let backpressurePromise: EventLoopPromise let messageReceived: EventLoopPromise init(eventLoop: EventLoop) { - self.lock = NIOLock() self.backpressurePromise = eventLoop.makePromise() self.messageReceived = eventLoop.makePromise() } - var reads: Int { - self.lock.withLock { - self._reads - } - } - func willExecuteOnChannel(_ channel: Channel) { - self.lock.withLock { - self._channel = channel + self.state.withLockedValue { + $0.channel = channel } } @@ -688,8 +687,8 @@ class HTTP1ConnectionTests: XCTestCase { func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { // We count a number of reads received. - self.lock.withLock { - self._reads += 1 + self.state.withLockedValue { + $0.reads += 1 } // We need to notify the test when first byte of the message is arrived. self.messageReceived.succeed(()) @@ -721,8 +720,8 @@ class HTTP1ConnectionTests: XCTestCase { let buffer = context.channel.allocator.buffer(string: "1234") context.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: nil) - self.endFuture.hop(to: context.eventLoop).whenSuccess { - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + self.endFuture.hop(to: context.eventLoop).assumeIsolated().whenSuccess { + context.writeAndFlush(Self.wrapOutboundOut(.end(nil)), promise: nil) } } } @@ -753,7 +752,7 @@ class HTTP1ConnectionTests: XCTestCase { ) guard let channel = maybeChannel else { return XCTFail("Expected to have a channel at this point") } let connectionDelegate = MockConnectionDelegate() - var maybeConnection: HTTP1Connection? + var maybeConnection: HTTP1Connection.SendableView? XCTAssertNoThrow( maybeConnection = try channel.eventLoop.submit { try HTTP1Connection.start( @@ -762,7 +761,7 @@ class HTTP1ConnectionTests: XCTestCase { delegate: connectionDelegate, decompression: .disabled, logger: logger - ) + ).sendableView }.wait() ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection at this point") } @@ -802,15 +801,20 @@ class HTTP1ConnectionTests: XCTestCase { } } -class MockHTTP1ConnectionDelegate: HTTP1ConnectionDelegate { - var releasePromise: EventLoopPromise? - var closePromise: EventLoopPromise? +final class MockHTTP1ConnectionDelegate: HTTP1ConnectionDelegate { + let releasePromise: EventLoopPromise? + let closePromise: EventLoopPromise? - func http1ConnectionReleased(_: HTTP1Connection) { + init(releasePromise: EventLoopPromise? = nil, closePromise: EventLoopPromise? = nil) { + self.releasePromise = releasePromise + self.closePromise = closePromise + } + + func http1ConnectionReleased(_: HTTPConnectionPool.Connection.ID) { self.releasePromise?.succeed(()) } - func http1ConnectionClosed(_: HTTP1Connection) { + func http1ConnectionClosed(_: HTTPConnectionPool.Connection.ID) { self.closePromise?.succeed(()) } } @@ -875,38 +879,40 @@ class AfterRequestCloseConnectionChannelHandler: ChannelInboundHandler { context.write(self.wrapOutboundOut(.end(nil)), promise: nil) context.flush() - context.eventLoop.scheduleTask(in: .milliseconds(20)) { + context.eventLoop.assumeIsolated().scheduleTask(in: .milliseconds(20)) { context.close(promise: nil) } } } } -class MockConnectionDelegate: HTTP1ConnectionDelegate { - private var lock = NIOLock() +final class MockConnectionDelegate: HTTP1ConnectionDelegate { + private let counts = NIOLockedValueBox(Counts()) - private var _hitConnectionReleased = 0 - private var _hitConnectionClosed = 0 + private struct Counts: Sendable { + var hitConnectionReleased = 0 + var hitConnectionClosed = 0 + } var hitConnectionReleased: Int { - self.lock.withLock { self._hitConnectionReleased } + self.counts.withLockedValue { $0.hitConnectionReleased } } var hitConnectionClosed: Int { - self.lock.withLock { self._hitConnectionClosed } + self.counts.withLockedValue { $0.hitConnectionClosed } } init() {} - func http1ConnectionReleased(_: HTTP1Connection) { - self.lock.withLock { - self._hitConnectionReleased += 1 + func http1ConnectionReleased(_: HTTPConnectionPool.Connection.ID) { + self.counts.withLockedValue { + $0.hitConnectionReleased += 1 } } - func http1ConnectionClosed(_: HTTP1Connection) { - self.lock.withLock { - self._hitConnectionClosed += 1 + func http1ConnectionClosed(_: HTTPConnectionPool.Connection.ID) { + self.counts.withLockedValue { + $0.hitConnectionClosed += 1 } } } diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift index 1f5f1b4c0..71f7f3d1a 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift @@ -568,10 +568,11 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { ) try channel.connect(to: .init(ipAddress: "127.0.0.1", port: 80)).wait() - let request = MockHTTPExecutableRequest() // non empty body is important to trigger this bug as we otherwise finish the request in a single flush - request.requestFramingMetadata.body = .fixedSize(1) - request.raiseErrorIfUnimplementedMethodIsCalled = false + let request = MockHTTPExecutableRequest( + framingMetadata: RequestFramingMetadata(connectionClose: false, body: .fixedSize(1)), + raiseErrorIfUnimplementedMethodIsCalled: false + ) channel.writeAndFlush(request, promise: nil) XCTAssertEqual(request.events.map(\.kind), [.willExecuteRequest, .requestHeadSent]) } diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift index d6bc2de14..183a227bd 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift @@ -14,6 +14,7 @@ import AsyncHTTPClient // NOT @testable - tests that really need @testable go into HTTP2ClientInternalTests.swift import Logging +import NIOConcurrencyHelpers import NIOCore import NIOFoundationCompat import NIOHTTP1 @@ -283,15 +284,16 @@ class HTTP2ClientTests: XCTestCase { XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost/(bin.port)")) guard let request = maybeRequest else { return } - var task: HTTPClient.Task! + let taskBox = NIOLockedValueBox?>(nil) let delegate = HeadReceivedCallback { _ in // request is definitely running because we just received a head from the server - task.cancel() + taskBox.withLockedValue { $0 }!.cancel() } - task = client.execute( + let task = client.execute( request: request, delegate: delegate ) + taskBox.withLockedValue { $0 = task } XCTAssertThrowsError(try task.futureResult.timeout(after: .seconds(2)).wait()) { XCTAssertEqualTypeAndValue($0, HTTPClientError.cancelled) @@ -360,18 +362,20 @@ class HTTP2ClientTests: XCTestCase { guard let request = maybeRequest else { return } let tasks = (0..<100).map { _ -> HTTPClient.Task in - var task: HTTPClient.Task! + let taskBox = NIOLockedValueBox?>(nil) + let delegate = HeadReceivedCallback { _ in // request is definitely running because we just received a head from the server cancelPool.next().execute { // canceling from a different thread - task.cancel() + taskBox.withLockedValue { $0 }!.cancel() } } - task = client.execute( + let task = client.execute( request: request, delegate: delegate ) + taskBox.withLockedValue { $0 = task } return task } @@ -547,8 +551,8 @@ class HTTP2ClientTests: XCTestCase { private final class HeadReceivedCallback: HTTPClientResponseDelegate { typealias Response = Void - private let didReceiveHeadCallback: (HTTPResponseHead) -> Void - init(didReceiveHead: @escaping (HTTPResponseHead) -> Void) { + private let didReceiveHeadCallback: @Sendable (HTTPResponseHead) -> Void + init(didReceiveHead: @escaping @Sendable (HTTPResponseHead) -> Void) { self.didReceiveHeadCallback = didReceiveHead } diff --git a/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift index a50f1ab54..3244e2b5a 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift @@ -44,7 +44,7 @@ class HTTP2ConnectionTests: XCTestCase { decompression: .disabled, maximumConnectionUses: nil, logger: logger - ).wait() + ).map { _ in }.nonisolated().wait() ) } @@ -70,7 +70,7 @@ class HTTP2ConnectionTests: XCTestCase { XCTAssertThrowsError(try startFuture.wait()) // should not crash - connection.shutdown() + connection.sendableView.shutdown() } func testSimpleGetRequest() { @@ -83,7 +83,7 @@ class HTTP2ConnectionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() - var maybeHTTP2Connection: HTTP2Connection? + var maybeHTTP2Connection: HTTP2Connection.SendableView? XCTAssertNoThrow( maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( to: httpBin.port, @@ -142,7 +142,7 @@ class HTTP2ConnectionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() - var maybeHTTP2Connection: HTTP2Connection? + var maybeHTTP2Connection: HTTP2Connection.SendableView? XCTAssertNoThrow( maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( to: httpBin.port, @@ -210,7 +210,7 @@ class HTTP2ConnectionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() - var maybeHTTP2Connection: HTTP2Connection? + var maybeHTTP2Connection: HTTP2Connection.SendableView? XCTAssertNoThrow( maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( to: httpBin.port, @@ -277,7 +277,7 @@ class HTTP2ConnectionTests: XCTestCase { func channelRead(context: ChannelHandlerContext, data: NIOAny) { self.dataArrivedPromise.succeed(()) - self.triggerResponseFuture.hop(to: context.eventLoop).whenSuccess { + self.triggerResponseFuture.hop(to: context.eventLoop).assumeIsolated().whenSuccess { switch self.unwrapInboundIn(data) { case .head: context.write(self.wrapOutboundOut(.head(.init(version: .http2, status: .ok))), promise: nil) @@ -305,7 +305,7 @@ class HTTP2ConnectionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() - var maybeHTTP2Connection: HTTP2Connection? + var maybeHTTP2Connection: HTTP2Connection.SendableView? XCTAssertNoThrow( maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( to: httpBin.port, @@ -385,7 +385,7 @@ class HTTP2ConnectionTests: XCTestCase { } } -class TestConnectionCreator { +final class TestConnectionCreator { enum Error: Swift.Error { case alreadyCreatingAnotherConnection case wantedHTTP2ConnectionButGotHTTP1 @@ -394,12 +394,11 @@ class TestConnectionCreator { enum State { case idle - case waitingForHTTP1Connection(EventLoopPromise) - case waitingForHTTP2Connection(EventLoopPromise) + case waitingForHTTP1Connection(EventLoopPromise) + case waitingForHTTP2Connection(EventLoopPromise) } - private var state: State = .idle - private let lock = NIOLock() + private let lock = NIOLockedValueBox(.idle) init() {} @@ -409,7 +408,7 @@ class TestConnectionCreator { connectionID: HTTPConnectionPool.Connection.ID = 0, on eventLoop: EventLoop, logger: Logger = .init(label: "test") - ) throws -> HTTP1Connection { + ) throws -> HTTP1Connection.SendableView { let request = try! HTTPClient.Request(url: "https://localhost/(port)") var tlsConfiguration = TLSConfiguration.makeClientConfiguration() @@ -423,13 +422,13 @@ class TestConnectionCreator { sslContextCache: .init() ) - let promise = try self.lock.withLock { () -> EventLoopPromise in - guard case .idle = self.state else { + let promise = try self.lock.withLockedValue { state in + guard case .idle = state else { throw Error.alreadyCreatingAnotherConnection } - let promise = eventLoop.makePromise(of: HTTP1Connection.self) - self.state = .waitingForHTTP1Connection(promise) + let promise = eventLoop.makePromise(of: HTTP1Connection.SendableView.self) + state = .waitingForHTTP1Connection(promise) return promise } @@ -452,7 +451,7 @@ class TestConnectionCreator { connectionID: HTTPConnectionPool.Connection.ID = 0, on eventLoop: EventLoop, logger: Logger = .init(label: "test") - ) throws -> HTTP2Connection { + ) throws -> HTTP2Connection.SendableView { let request = try! HTTPClient.Request(url: "https://localhost/(port)") var tlsConfiguration = TLSConfiguration.makeClientConfiguration() @@ -466,13 +465,13 @@ class TestConnectionCreator { sslContextCache: .init() ) - let promise = try self.lock.withLock { () -> EventLoopPromise in - guard case .idle = self.state else { + let promise = try self.lock.withLockedValue { state in + guard case .idle = state else { throw Error.alreadyCreatingAnotherConnection } - let promise = eventLoop.makePromise(of: HTTP2Connection.self) - self.state = .waitingForHTTP2Connection(promise) + let promise = eventLoop.makePromise(of: HTTP2Connection.SendableView.self) + state = .waitingForHTTP2Connection(promise) return promise } @@ -491,7 +490,7 @@ class TestConnectionCreator { } extension TestConnectionCreator: HTTPConnectionRequester { - enum EitherPromiseWrapper { + enum EitherPromiseWrapper: Sendable { case succeed(EventLoopPromise, SucceedType) case fail(EventLoopPromise, Error) @@ -505,37 +504,38 @@ extension TestConnectionCreator: HTTPConnectionRequester { } } - func http1ConnectionCreated(_ connection: HTTP1Connection) { - let wrapper = self.lock.withLock { () -> (EitherPromiseWrapper) in + func http1ConnectionCreated(_ connection: HTTP1Connection.SendableView) { + let wrapper: EitherPromiseWrapper = self.lock + .withLockedValue { state in - switch self.state { - case .waitingForHTTP1Connection(let promise): - return .succeed(promise, connection) + switch state { + case .waitingForHTTP1Connection(let promise): + return .succeed(promise, connection) - case .waitingForHTTP2Connection(let promise): - return .fail(promise, Error.wantedHTTP2ConnectionButGotHTTP1) + case .waitingForHTTP2Connection(let promise): + return .fail(promise, Error.wantedHTTP2ConnectionButGotHTTP1) - case .idle: - preconditionFailure("Invalid state: \(self.state)") + case .idle: + preconditionFailure("Invalid state: \(state)") + } } - } wrapper.complete() } - func http2ConnectionCreated(_ connection: HTTP2Connection, maximumStreams: Int) { - let wrapper = self.lock.withLock { () -> (EitherPromiseWrapper) in + func http2ConnectionCreated(_ connection: HTTP2Connection.SendableView, maximumStreams: Int) { + let wrapper: EitherPromiseWrapper = self.lock + .withLockedValue { state in + switch state { + case .waitingForHTTP1Connection(let promise): + return .fail(promise, Error.wantedHTTP1ConnectionButGotHTTP2) - switch self.state { - case .waitingForHTTP1Connection(let promise): - return .fail(promise, Error.wantedHTTP1ConnectionButGotHTTP2) + case .waitingForHTTP2Connection(let promise): + return .succeed(promise, connection) - case .waitingForHTTP2Connection(let promise): - return .succeed(promise, connection) - - case .idle: - preconditionFailure("Invalid state: \(self.state)") + case .idle: + preconditionFailure("Invalid state: \(state)") + } } - } wrapper.complete() } @@ -554,19 +554,20 @@ extension TestConnectionCreator: HTTPConnectionRequester { } func failedToCreateHTTPConnection(_: HTTPConnectionPool.Connection.ID, error: Swift.Error) { - let wrapper = self.lock.withLock { () -> (FailPromiseWrapper) in + let wrapper: FailPromiseWrapper = self.lock + .withLockedValue { state in - switch self.state { - case .waitingForHTTP1Connection(let promise): - return .type1(promise) + switch state { + case .waitingForHTTP1Connection(let promise): + return .type1(promise) - case .waitingForHTTP2Connection(let promise): - return .type2(promise) + case .waitingForHTTP2Connection(let promise): + return .type2(promise) - case .idle: - preconditionFailure("Invalid state: \(self.state)") + case .idle: + preconditionFailure("Invalid state: \(state)") + } } - } wrapper.fail(error) } @@ -575,76 +576,78 @@ extension TestConnectionCreator: HTTPConnectionRequester { } } -class TestHTTP2ConnectionDelegate: HTTP2ConnectionDelegate { +final class TestHTTP2ConnectionDelegate: HTTP2ConnectionDelegate { var hitStreamClosed: Int { - self.lock.withLock { self._hitStreamClosed } + self.lock.withLockedValue { $0.hitStreamClosed } } var hitGoAwayReceived: Int { - self.lock.withLock { self._hitGoAwayReceived } + self.lock.withLockedValue { $0.hitGoAwayReceived } } var hitConnectionClosed: Int { - self.lock.withLock { self._hitConnectionClosed } + self.lock.withLockedValue { $0.hitConnectionClosed } } var maxStreamSetting: Int { - self.lock.withLock { self._maxStreamSetting } + self.lock.withLockedValue { $0.maxStreamSetting } } - private let lock = NIOLock() - private var _hitStreamClosed: Int = 0 - private var _hitGoAwayReceived: Int = 0 - private var _hitConnectionClosed: Int = 0 - private var _maxStreamSetting: Int = 100 + private let lock = NIOLockedValueBox(.init()) + private struct Counts { + var hitStreamClosed: Int = 0 + var hitGoAwayReceived: Int = 0 + var hitConnectionClosed: Int = 0 + var maxStreamSetting: Int = 100 + } init() {} - func http2Connection(_: HTTP2Connection, newMaxStreamSetting: Int) {} + func http2Connection(_: HTTPConnectionPool.Connection.ID, newMaxStreamSetting: Int) {} - func http2ConnectionStreamClosed(_: HTTP2Connection, availableStreams: Int) { - self.lock.withLock { - self._hitStreamClosed += 1 + func http2ConnectionStreamClosed(_: HTTPConnectionPool.Connection.ID, availableStreams: Int) { + self.lock.withLockedValue { + $0.hitStreamClosed += 1 } } - func http2ConnectionGoAwayReceived(_: HTTP2Connection) { - self.lock.withLock { - self._hitGoAwayReceived += 1 + func http2ConnectionGoAwayReceived(_: HTTPConnectionPool.Connection.ID) { + self.lock.withLockedValue { + $0.hitGoAwayReceived += 1 } } - func http2ConnectionClosed(_: HTTP2Connection) { - self.lock.withLock { - self._hitConnectionClosed += 1 + func http2ConnectionClosed(_: HTTPConnectionPool.Connection.ID) { + self.lock.withLockedValue { + $0.hitConnectionClosed += 1 } } } final class EmptyHTTP2ConnectionDelegate: HTTP2ConnectionDelegate { - func http2Connection(_: HTTP2Connection, newMaxStreamSetting: Int) { + func http2Connection(_: HTTPConnectionPool.Connection.ID, newMaxStreamSetting: Int) { preconditionFailure("Unimplemented") } - func http2ConnectionStreamClosed(_: HTTP2Connection, availableStreams: Int) { + func http2ConnectionStreamClosed(_: HTTPConnectionPool.Connection.ID, availableStreams: Int) { preconditionFailure("Unimplemented") } - func http2ConnectionGoAwayReceived(_: HTTP2Connection) { + func http2ConnectionGoAwayReceived(_: HTTPConnectionPool.Connection.ID) { preconditionFailure("Unimplemented") } - func http2ConnectionClosed(_: HTTP2Connection) { + func http2ConnectionClosed(_: HTTPConnectionPool.Connection.ID) { preconditionFailure("Unimplemented") } } final class EmptyHTTP1ConnectionDelegate: HTTP1ConnectionDelegate { - func http1ConnectionReleased(_: HTTP1Connection) { + func http1ConnectionReleased(_: HTTPConnectionPool.Connection.ID) { preconditionFailure("Unimplemented") } - func http1ConnectionClosed(_: HTTP1Connection) { + func http1ConnectionClosed(_: HTTPConnectionPool.Connection.ID) { preconditionFailure("Unimplemented") } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClient+StructuredConcurrencyTests.swift b/Tests/AsyncHTTPClientTests/HTTPClient+StructuredConcurrencyTests.swift new file mode 100644 index 000000000..a7cc1f454 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTPClient+StructuredConcurrencyTests.swift @@ -0,0 +1,101 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2025 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import NIO +import NIOFoundationCompat +import NIOHTTP1 +import XCTest + +final class HTTPClientStructuredConcurrencyTests: XCTestCase { + func testDoNothingWorks() async throws { + let actual = try await HTTPClient.withHTTPClient { httpClient in + "OK" + } + XCTAssertEqual("OK", actual) + } + + func testShuttingDownTheClientInBodyLeadsToError() async { + do { + let actual = try await HTTPClient.withHTTPClient { httpClient in + try await httpClient.shutdown() + return "OK" + } + XCTFail("Expected error, got \(actual)") + } catch let error as HTTPClientError where error == .alreadyShutdown { + // OK + } catch { + XCTFail("unexpected error: \(error)") + } + } + + func testBasicRequest() async throws { + let httpBin = HTTPBin() + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let actualBytes = try await HTTPClient.withHTTPClient { httpClient in + let response = try await httpClient.get(url: httpBin.baseURL).get() + XCTAssertEqual(response.status, .ok) + return response.body ?? ByteBuffer(string: "n/a") + } + let actual = try JSONDecoder().decode(RequestInfo.self, from: actualBytes) + + XCTAssertGreaterThanOrEqual(actual.requestNumber, 0) + XCTAssertGreaterThanOrEqual(actual.connectionNumber, 0) + } + + func testClientIsShutDownAfterReturn() async throws { + let leakedClient = try await HTTPClient.withHTTPClient { httpClient in + httpClient + } + do { + try await leakedClient.shutdown() + XCTFail("unexpected, shutdown should have failed") + } catch let error as HTTPClientError where error == .alreadyShutdown { + // OK + } catch { + XCTFail("unexpected error: \(error)") + } + } + + func testClientIsShutDownOnThrowAlso() async throws { + struct TestError: Error { + var httpClient: HTTPClient + } + + let leakedClient: HTTPClient + do { + try await HTTPClient.withHTTPClient { httpClient in + throw TestError(httpClient: httpClient) + } + XCTFail("unexpected, shutdown should have failed") + return + } catch let error as TestError { + // OK + leakedClient = error.httpClient + } catch { + XCTFail("unexpected error: \(error)") + return + } + + do { + try await leakedClient.shutdown() + XCTFail("unexpected, shutdown should have failed") + } catch let error as HTTPClientError where error == .alreadyShutdown { + // OK + } catch { + XCTFail("unexpected error: \(error)") + } + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index 5b70699a0..634efc14c 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -164,10 +164,10 @@ class HTTPClientInternalTests: XCTestCase { } func testChannelAndDelegateOnDifferentEventLoops() throws { - class Delegate: HTTPClientResponseDelegate { + final class Delegate: HTTPClientResponseDelegate { typealias Response = ([Message], [Message]) - enum Message { + enum Message: Sendable { case head(HTTPResponseHead) case bodyPart(ByteBuffer) case sentRequestHead(HTTPRequestHead) @@ -176,33 +176,51 @@ class HTTPClientInternalTests: XCTestCase { case error(Error) } - var receivedMessages: [Message] = [] - var sentMessages: [Message] = [] + private struct Messages: Sendable { + var received: [Message] = [] + var sent: [Message] = [] + } + + private let messages: NIOLoopBoundBox + + var receivedMessages: [Message] { + get { + self.messages.value.received + } + set { + self.messages.value.received = newValue + } + } + var sentMessages: [Message] { + get { + self.messages.value.sent + } + set { + self.messages.value.sent = newValue + } + } private let eventLoop: EventLoop private let randoEL: EventLoop init(expectedEventLoop: EventLoop, randomOtherEventLoop: EventLoop) { self.eventLoop = expectedEventLoop self.randoEL = randomOtherEventLoop + self.messages = .makeBoxSendingValue(Messages(), eventLoop: expectedEventLoop) } func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { - self.eventLoop.assertInEventLoop() self.sentMessages.append(.sentRequestHead(head)) } func didSendRequestPart(task: HTTPClient.Task, _ part: IOData) { - self.eventLoop.assertInEventLoop() self.sentMessages.append(.sentRequestPart(part)) } func didSendRequest(task: HTTPClient.Task) { - self.eventLoop.assertInEventLoop() self.sentMessages.append(.sentRequest) } func didReceiveError(task: HTTPClient.Task, _ error: Error) { - self.eventLoop.assertInEventLoop() self.receivedMessages.append(.error(error)) } @@ -210,7 +228,6 @@ class HTTPClientInternalTests: XCTestCase { task: HTTPClient.Task, _ head: HTTPResponseHead ) -> EventLoopFuture { - self.eventLoop.assertInEventLoop() self.receivedMessages.append(.head(head)) return self.randoEL.makeSucceededFuture(()) } @@ -219,14 +236,12 @@ class HTTPClientInternalTests: XCTestCase { task: HTTPClient.Task, _ buffer: ByteBuffer ) -> EventLoopFuture { - self.eventLoop.assertInEventLoop() self.receivedMessages.append(.bodyPart(buffer)) return self.randoEL.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws -> Response { - self.eventLoop.assertInEventLoop() - return (self.receivedMessages, self.sentMessages) + (self.receivedMessages, self.sentMessages) } } @@ -460,11 +475,15 @@ class HTTPClientInternalTests: XCTestCase { } func testConnectErrorCalloutOnCorrectEL() throws { - class TestDelegate: HTTPClientResponseDelegate { + final class TestDelegate: HTTPClientResponseDelegate { typealias Response = Void let expectedEL: EventLoop - var receivedError: Bool = false + let _receivedError = NIOLockedValueBox(false) + + var receivedError: Bool { + self._receivedError.withLockedValue { $0 } + } init(expectedEL: EventLoop) { self.expectedEL = expectedEL @@ -473,7 +492,7 @@ class HTTPClientInternalTests: XCTestCase { func didFinishRequest(task: HTTPClient.Task) throws {} func didReceiveError(task: HTTPClient.Task, _ error: Error) { - self.receivedError = true + self._receivedError.withLockedValue { $0 = true } XCTAssertTrue(self.expectedEL.inEventLoop) } } @@ -658,7 +677,8 @@ class HTTPClientInternalTests: XCTestCase { ).futureResult } _ = try EventLoopFuture.whenAllSucceed(resultFutures, on: self.clientGroup.next()).wait() - let threadPools = delegates.map { $0.fileIOThreadPool } + + let threadPools = delegates.map { $0._fileIOThreadPool } let firstThreadPool = threadPools.first ?? nil XCTAssert(threadPools.dropFirst().allSatisfy { $0 === firstThreadPool }) } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientRequestTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientRequestTests.swift index a2cc3b108..54467aab7 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientRequestTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientRequestTests.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// import Algorithms +import NIOConcurrencyHelpers import NIOCore import NIOHTTP1 import XCTest @@ -493,7 +494,7 @@ class HTTPClientRequestTests: XCTestCase { request.method = .POST let asyncSequence = ByteBuffer(string: "post body") .readableBytesView - .chunks(ofCount: 2) + .uncheckedSendableChunks(ofCount: 2) .async .map { ByteBuffer($0) } @@ -541,7 +542,7 @@ class HTTPClientRequestTests: XCTestCase { request.method = .POST let asyncSequence = ByteBuffer(string: "post body") .readableBytesView - .chunks(ofCount: 2) + .uncheckedSendableChunks(ofCount: 2) .async .map { ByteBuffer($0) } @@ -619,7 +620,7 @@ class HTTPClientRequestTests: XCTestCase { func testChunkingSequenceThatDoesNotImplementWithContiguousStorageIfAvailable() async throws { let bagOfBytesToByteBufferConversionChunkSize = 8 let body = try await HTTPClientRequest.Body._bytes( - AnySequence( + AnySendableSequence( Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) ), @@ -729,17 +730,17 @@ extension HTTPClient.Body { func collect() -> EventLoopFuture<[ByteBuffer]> { let eelg = EmbeddedEventLoopGroup(loops: 1) let el = eelg.next() - var body = [ByteBuffer]() + let body = NIOLockedValueBox<[ByteBuffer]>([]) let writer = StreamWriter { switch $0 { case .byteBuffer(let byteBuffer): - body.append(byteBuffer) + body.withLockedValue { $0.append(byteBuffer) } case .fileRegion: fatalError("file region not supported") } return el.makeSucceededVoidFuture() } - return self.stream(writer).map { _ in body } + return self.stream(writer).map { _ in body.withLockedValue { $0 } } } } @@ -766,8 +767,9 @@ extension Optional where Wrapped == HTTPClientRequest.Prepared.Body { throw LengthMismatch(announcedLength: announcedLength, actualLength: Int64(buffer.readableBytes)) } return buffer - case .asyncSequence(length: let announcedLength, let generate): + case .asyncSequence(length: let announcedLength, let makeAsyncIterator): var accumulatedBuffer = ByteBuffer() + let generate = makeAsyncIterator() while var buffer = try await generate(ByteBufferAllocator()) { accumulatedBuffer.writeBuffer(&buffer) } @@ -783,3 +785,35 @@ extension Optional where Wrapped == HTTPClientRequest.Prepared.Body { } } } + +// swift-algorithms hasn't adopted Sendable yet. By inspection ChunksOfCountCollection should be +// Sendable assuming the underlying collection is. This wrapper allows us to avoid a blanket +// preconcurrency import of the Algorithms module. +struct UncheckedSendableChunksOfCountCollection: Collection, @unchecked Sendable +where Base: Sendable { + typealias Element = Base.SubSequence + typealias Index = ChunksOfCountCollection.Index + + private let underlying: ChunksOfCountCollection + + init(_ underlying: ChunksOfCountCollection) { + self.underlying = underlying + } + + var startIndex: Index { self.underlying.startIndex } + var endIndex: Index { self.underlying.endIndex } + + subscript(position: Index) -> Base.SubSequence { + self.underlying[position] + } + + func index(after i: Index) -> Index { + self.underlying.index(after: i) + } +} + +extension Collection where Self: Sendable { + func uncheckedSendableChunks(ofCount count: Int) -> UncheckedSendableChunksOfCountCollection { + UncheckedSendableChunksOfCountCollection(self.chunks(ofCount: count)) + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index da2046b81..f9917c885 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -39,6 +39,8 @@ import locale_h import Darwin #elseif canImport(Musl) import Musl +#elseif canImport(Android) +import Android #elseif canImport(Glibc) import Glibc #endif @@ -93,15 +95,13 @@ func withCLocaleSetToGerman(_ body: () throws -> Void) throws { try body() } -class TestHTTPDelegate: HTTPClientResponseDelegate { +final class TestHTTPDelegate: HTTPClientResponseDelegate { typealias Response = Void init(backpressureEventLoop: EventLoop? = nil) { - self.backpressureEventLoop = backpressureEventLoop + self.state = NIOLockedValueBox(MutableState(backpressureEventLoop: backpressureEventLoop)) } - var backpressureEventLoop: EventLoop? - enum State { case idle case head(HTTPResponseHead) @@ -110,77 +110,96 @@ class TestHTTPDelegate: HTTPClientResponseDelegate { case error(Error) } - var state = State.idle + struct MutableState: Sendable { + var state: State = .idle + var backpressureEventLoop: EventLoop? + } + + let state: NIOLockedValueBox func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - self.state = .head(head) - return (self.backpressureEventLoop ?? task.eventLoop).makeSucceededFuture(()) + let eventLoop = self.state.withLockedValue { + $0.state = .head(head) + return ($0.backpressureEventLoop ?? task.eventLoop) + } + + return eventLoop.makeSucceededVoidFuture() } func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { - switch self.state { - case .head(let head): - self.state = .body(head, buffer) - case .body(let head, var body): - var buffer = buffer - body.writeBuffer(&buffer) - self.state = .body(head, body) - default: - preconditionFailure("expecting head or body") + let eventLoop = self.state.withLockedValue { + switch $0.state { + case .head(let head): + $0.state = .body(head, buffer) + case .body(let head, var body): + var buffer = buffer + body.writeBuffer(&buffer) + $0.state = .body(head, body) + default: + preconditionFailure("expecting head or body") + } + return ($0.backpressureEventLoop ?? task.eventLoop) } - return (self.backpressureEventLoop ?? task.eventLoop).makeSucceededFuture(()) + + return eventLoop.makeSucceededVoidFuture() } func didFinishRequest(task: HTTPClient.Task) throws {} } -class CountingDelegate: HTTPClientResponseDelegate { +final class CountingDelegate: HTTPClientResponseDelegate { typealias Response = Int - var count = 0 + private let _count = NIOLockedValueBox(0) func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { let str = buffer.getString(at: 0, length: buffer.readableBytes) if str?.starts(with: "id:") ?? false { - self.count += 1 + self._count.withLockedValue { $0 += 1 } } return task.eventLoop.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws -> Int { - self.count + self._count.withLockedValue { $0 } } } -class DelayOnHeadDelegate: HTTPClientResponseDelegate { +final class DelayOnHeadDelegate: HTTPClientResponseDelegate { typealias Response = ByteBuffer let eventLoop: EventLoop - let didReceiveHead: (HTTPResponseHead, EventLoopPromise) -> Void - - private var data: ByteBuffer + let didReceiveHead: @Sendable (HTTPResponseHead, EventLoopPromise) -> Void - private var mayReceiveData = false + struct State: Sendable { + var data: ByteBuffer + var mayReceiveData = false + var expectError = false + } - private var expectError = false + private let state: NIOLockedValueBox - init(eventLoop: EventLoop, didReceiveHead: @escaping (HTTPResponseHead, EventLoopPromise) -> Void) { + init(eventLoop: EventLoop, didReceiveHead: @escaping @Sendable (HTTPResponseHead, EventLoopPromise) -> Void) { self.eventLoop = eventLoop self.didReceiveHead = didReceiveHead - self.data = ByteBuffer() + self.state = NIOLockedValueBox(State(data: ByteBuffer())) } func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - XCTAssertFalse(self.mayReceiveData) - XCTAssertFalse(self.expectError) + self.state.withLockedValue { + XCTAssertFalse($0.mayReceiveData) + XCTAssertFalse($0.expectError) + } let promise = self.eventLoop.makePromise(of: Void.self) - promise.futureResult.whenComplete { - switch $0 { - case .success: - self.mayReceiveData = true - case .failure: - self.expectError = true + promise.futureResult.whenComplete { result in + self.state.withLockedValue { state in + switch result { + case .success: + state.mayReceiveData = true + case .failure: + state.expectError = true + } } } @@ -189,20 +208,26 @@ class DelayOnHeadDelegate: HTTPClientResponseDelegate { } func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { - XCTAssertTrue(self.mayReceiveData) - XCTAssertFalse(self.expectError) - self.data.writeImmutableBuffer(buffer) + self.state.withLockedValue { + XCTAssertTrue($0.mayReceiveData) + XCTAssertFalse($0.expectError) + $0.data.writeImmutableBuffer(buffer) + } return self.eventLoop.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws -> Response { - XCTAssertTrue(self.mayReceiveData) - XCTAssertFalse(self.expectError) - return self.data + self.state.withLockedValue { + XCTAssertTrue($0.mayReceiveData) + XCTAssertFalse($0.expectError) + return $0.data + } } func didReceiveError(task: HTTPClient.Task, _ error: Error) { - XCTAssertTrue(self.expectError) + self.state.withLockedValue { + XCTAssertTrue($0.expectError) + } } } @@ -262,7 +287,7 @@ enum TemporaryFileHelpers { shortEnoughPath = path restoreSavedCWD = false } catch SocketAddressError.unixDomainSocketPathTooLong { - FileManager.default.changeCurrentDirectoryPath( + _ = FileManager.default.changeCurrentDirectoryPath( URL(fileURLWithPath: path).deletingLastPathComponent().absoluteString ) shortEnoughPath = URL(fileURLWithPath: path).lastPathComponent @@ -276,7 +301,7 @@ enum TemporaryFileHelpers { try? FileManager.default.removeItem(atPath: path) } if restoreSavedCWD { - FileManager.default.changeCurrentDirectoryPath(saveCurrentDirectory) + _ = FileManager.default.changeCurrentDirectoryPath(saveCurrentDirectory) } } return try body(shortEnoughPath) @@ -334,7 +359,7 @@ enum TestTLS { ) } -internal final class HTTPBin +internal final class HTTPBin: Sendable where RequestHandler.InboundIn == HTTPServerRequestPart, RequestHandler.OutboundOut == HTTPServerResponsePart @@ -413,11 +438,15 @@ where } var port: Int { - Int(self.serverChannel.localAddress!.port!) + self.serverChannel.withLockedValue { + Int($0!.localAddress!.port!) + } } var socketAddress: SocketAddress { - self.serverChannel.localAddress! + self.serverChannel.withLockedValue { + $0!.localAddress! + } } var baseURL: String { @@ -445,9 +474,9 @@ where private let mode: Mode private let sslContext: NIOSSLContext? - private var serverChannel: Channel! + private let serverChannel = NIOLockedValueBox(nil) private let isShutdown = ManagedAtomic(false) - private let handlerFactory: (Int) -> (RequestHandler) + private let handlerFactory: @Sendable (Int) -> (RequestHandler) init( _ mode: Mode = .http1_1(ssl: false, compress: false), @@ -455,7 +484,7 @@ where bindTarget: BindTarget = .localhostIPv4RandomPort, reusePort: Bool = false, trafficShapingTargetBytesPerSecond: Int? = nil, - handlerFactory: @escaping (Int) -> (RequestHandler) + handlerFactory: @escaping @Sendable (Int) -> (RequestHandler) ) { self.mode = mode self.sslContext = HTTPBin.sslContext(for: mode) @@ -475,14 +504,14 @@ where let connectionIDAtomic = ManagedAtomic(0) - self.serverChannel = try! ServerBootstrap(group: self.group) + let serverChannel = try! ServerBootstrap(group: self.group) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .serverChannelOption( ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT), value: reusePort ? 1 : 0 ) - .serverChannelInitializer { channel in - channel.pipeline.addHandler(self.activeConnCounterHandler) + .serverChannelInitializer { [activeConnCounterHandler] channel in + channel.pipeline.addHandler(activeConnCounterHandler) }.childChannelInitializer { channel in if let trafficShapingTargetBytesPerSecond = trafficShapingTargetBytesPerSecond { try! channel.pipeline.syncOperations.addHandler( @@ -526,6 +555,7 @@ where return channel.eventLoop.makeFailedFuture(error) } }.bind(to: socketAddress).wait() + self.serverChannel.withLockedValue { $0 = serverChannel } } private func syncAddHTTPProxyHandlers( @@ -544,12 +574,12 @@ where try sync.addHandler(requestDecoder) try sync.addHandler(proxySimulator) - promise.futureResult.flatMap { _ in - channel.pipeline.removeHandler(proxySimulator) + promise.futureResult.assumeIsolated().flatMap { _ in + channel.pipeline.syncOperations.removeHandler(proxySimulator) }.flatMap { _ in - channel.pipeline.removeHandler(responseEncoder) + channel.pipeline.syncOperations.removeHandler(responseEncoder) }.flatMap { _ in - channel.pipeline.removeHandler(requestDecoder) + channel.pipeline.syncOperations.removeHandler(requestDecoder) }.whenComplete { result in switch result { case .failure: @@ -653,8 +683,8 @@ where } } + try channel.pipeline.syncOperations.addHandler(sslHandler) try channel.pipeline.syncOperations.addHandler(alpnHandler) - try channel.pipeline.syncOperations.addHandler(sslHandler, position: .before(alpnHandler)) } func shutdown() throws { @@ -1090,13 +1120,13 @@ internal final class HTTPBinHandler: ChannelInboundHandler { ) context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil) } - context.eventLoop.scheduleTask(in: self.delay) { + context.eventLoop.assumeIsolated().scheduleTask(in: self.delay) { guard context.channel.isActive else { context.close(promise: nil) return } - context.writeAndFlush(self.wrapOutboundOut(.end(nil))).whenComplete { result in + context.writeAndFlush(self.wrapOutboundOut(.end(nil))).assumeIsolated().whenComplete { result in self.isServingRequest = false switch result { case .success: @@ -1131,7 +1161,7 @@ internal final class HTTPBinHandler: ChannelInboundHandler { } } -final class ConnectionsCountHandler: ChannelInboundHandler { +final class ConnectionsCountHandler: ChannelInboundHandler, Sendable { typealias InboundIn = Channel private let activeConns = ManagedAtomic(0) @@ -1150,8 +1180,8 @@ final class ConnectionsCountHandler: ChannelInboundHandler { _ = self.activeConns.loadThenWrappingIncrement(ordering: .relaxed) _ = self.createdConns.loadThenWrappingIncrement(ordering: .relaxed) - channel.closeFuture.whenComplete { _ in - _ = self.activeConns.loadThenWrappingDecrement(ordering: .relaxed) + channel.closeFuture.whenComplete { [activeConns] _ in + _ = activeConns.loadThenWrappingDecrement(ordering: .relaxed) } context.fireChannelRead(data) @@ -1171,7 +1201,7 @@ internal final class CloseWithoutClosingServerHandler: ChannelInboundHandler { func handlerAdded(context: ChannelHandlerContext) { self.onClosePromise = context.eventLoop.makePromise() - self.onClosePromise!.futureResult.whenSuccess(self.callback!) + self.onClosePromise!.futureResult.assumeIsolated().whenSuccess(self.callback!) self.callback = nil } @@ -1233,7 +1263,7 @@ final class ExpectClosureServerHandler: ChannelInboundHandler { struct EventLoopFutureTimeoutError: Error {} -extension EventLoopFuture { +extension EventLoopFuture where Value: Sendable { func timeout(after failDelay: TimeAmount) -> EventLoopFuture { let promise = self.eventLoop.makePromise(of: Value.self) @@ -1259,28 +1289,27 @@ struct CollectEverythingLogHandler: LogHandler { var logLevel: Logger.Level = .info let logStore: LogStore - class LogStore { + final class LogStore: Sendable { struct Entry { var level: Logger.Level var message: String var metadata: [String: String] } - var lock = NIOLock() - var logs: [Entry] = [] + private let logs = NIOLockedValueBox<[Entry]>([]) var allEntries: [Entry] { get { - self.lock.withLock { self.logs } + self.logs.withLockedValue { $0 } } set { - self.lock.withLock { self.logs = newValue } + self.logs.withLockedValue { $0 = newValue } } } func append(level: Logger.Level, message: Logger.Message, metadata: Logger.Metadata?) { - self.lock.withLock { - self.logs.append( + self.logs.withLockedValue { + $0.append( Entry( level: level, message: message.description, @@ -1299,6 +1328,7 @@ struct CollectEverythingLogHandler: LogHandler { level: Logger.Level, message: Logger.Message, metadata: Logger.Metadata?, + source: String, file: String, function: String, line: UInt @@ -1320,10 +1350,10 @@ struct CollectEverythingLogHandler: LogHandler { /// consume the bytes by calling ``next()`` on the delegate. /// /// The sole purpose of this class is to enable straight-line stream tests. -class ResponseStreamDelegate: HTTPClientResponseDelegate { +final class ResponseStreamDelegate: HTTPClientResponseDelegate { typealias Response = Void - enum State { + enum State: Sendable { /// The delegate is in the idle state. There are no http response parts to be buffered /// and the consumer did not signal a demand. Transitions to all other states are allowed. case idle @@ -1341,10 +1371,11 @@ class ResponseStreamDelegate: HTTPClientResponseDelegate { } let eventLoop: EventLoop - private var state: State = .idle + private let state: NIOLoopBoundBox init(eventLoop: EventLoop) { self.eventLoop = eventLoop + self.state = .makeBoxSendingValue(.idle, eventLoop: eventLoop) } func next() -> EventLoopFuture { @@ -1358,25 +1389,25 @@ class ResponseStreamDelegate: HTTPClientResponseDelegate { } private func next0() -> EventLoopFuture { - switch self.state { + switch self.state.value { case .idle: let promise = self.eventLoop.makePromise(of: ByteBuffer?.self) - self.state = .waitingForBytes(promise) + self.state.value = .waitingForBytes(promise) return promise.futureResult case .buffering(let byteBuffer, done: false): - self.state = .idle + self.state.value = .idle return self.eventLoop.makeSucceededFuture(byteBuffer) case .buffering(let byteBuffer, done: true): - self.state = .finished + self.state.value = .finished return self.eventLoop.makeSucceededFuture(byteBuffer) case .waitingForBytes: preconditionFailure("Don't call `.next` twice") case .failed(let error): - self.state = .finished + self.state.value = .finished return self.eventLoop.makeFailedFuture(error) case .finished: @@ -1406,16 +1437,16 @@ class ResponseStreamDelegate: HTTPClientResponseDelegate { func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { self.eventLoop.preconditionInEventLoop() - switch self.state { + switch self.state.value { case .idle: - self.state = .buffering(buffer, done: false) + self.state.value = .buffering(buffer, done: false) case .waitingForBytes(let promise): - self.state = .idle + self.state.value = .idle promise.succeed(buffer) case .buffering(var byteBuffer, done: false): var buffer = buffer byteBuffer.writeBuffer(&buffer) - self.state = .buffering(byteBuffer, done: false) + self.state.value = .buffering(byteBuffer, done: false) case .buffering(_, done: true), .finished, .failed: preconditionFailure("Invalid state: \(self.state)") } @@ -1426,14 +1457,14 @@ class ResponseStreamDelegate: HTTPClientResponseDelegate { func didReceiveError(task: HTTPClient.Task, _ error: Error) { self.eventLoop.preconditionInEventLoop() - switch self.state { + switch self.state.value { case .idle: - self.state = .failed(error) + self.state.value = .failed(error) case .waitingForBytes(let promise): - self.state = .finished + self.state.value = .finished promise.fail(error) case .buffering(_, done: false): - self.state = .failed(error) + self.state.value = .failed(error) case .buffering(_, done: true), .finished, .failed: preconditionFailure("Invalid state: \(self.state)") } @@ -1442,14 +1473,14 @@ class ResponseStreamDelegate: HTTPClientResponseDelegate { func didFinishRequest(task: HTTPClient.Task) throws { self.eventLoop.preconditionInEventLoop() - switch self.state { + switch self.state.value { case .idle: - self.state = .finished + self.state.value = .finished case .waitingForBytes(let promise): - self.state = .finished + self.state.value = .finished promise.succeed(nil) case .buffering(let byteBuffer, done: false): - self.state = .buffering(byteBuffer, done: true) + self.state.value = .buffering(byteBuffer, done: true) case .buffering(_, done: true), .finished, .failed: preconditionFailure("Invalid state: \(self.state)") } @@ -1471,7 +1502,7 @@ class HTTPEchoHandler: ChannelInboundHandler { case .body(let bytes): context.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(bytes))), promise: nil) case .end: - context.writeAndFlush(self.wrapOutboundOut(.end(nil))).whenSuccess { + context.writeAndFlush(self.wrapOutboundOut(.end(nil))).assumeIsolated().whenSuccess { context.close(promise: nil) } } @@ -1493,7 +1524,7 @@ final class HTTPEchoHeaders: ChannelInboundHandler { case .body: break case .end: - context.writeAndFlush(self.wrapOutboundOut(.end(nil))).whenSuccess { + context.writeAndFlush(self.wrapOutboundOut(.end(nil))).assumeIsolated().whenSuccess { context.close(promise: nil) } } @@ -1659,7 +1690,7 @@ final class BasicInboundTrafficShapingHandler: ChannelDuplexHandler { let buffer = Self.unwrapInboundIn(data) let byteCount = buffer.readableBytes self.currentSecondBytesSeen += byteCount - context.eventLoop.scheduleTask(in: .seconds(1)) { + context.eventLoop.assumeIsolated().scheduleTask(in: .seconds(1)) { self.currentSecondBytesSeen -= byteCount self.evaluatePause(context: loopBoundContext.value) } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index fbd40ce3a..50c3ecb9d 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -46,7 +46,7 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let request3 = try Request(url: "unix:///tmp/file") XCTAssertEqual(request3.host, "") - #if os(Linux) && compiler(>=6.0) + #if os(Linux) && compiler(>=6.0) && compiler(<6.1) XCTAssertEqual(request3.url.host, "") #else XCTAssertNil(request3.url.host) @@ -314,10 +314,9 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testPostWithGenericBody() throws { - let bodyData = Array("hello, world!").lazy.map { $0.uppercased().first!.asciiValue! } - let erasedData = AnyRandomAccessCollection(bodyData) + let bodyData = Array(Array("hello, world!").lazy.map { $0.uppercased().first!.asciiValue! }) - let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .bytes(erasedData)) + let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .bytes(bodyData)) .wait() let bytes = response.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) } let data = try JSONDecoder().decode(RequestInfo.self, from: bytes!) @@ -469,6 +468,14 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { var response = try localClient.get(url: self.defaultHTTPBinURLPrefix + "redirect/302").wait() XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, self.defaultHTTPBinURLPrefix + "ok") + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [ + self.defaultHTTPBinURLPrefix + "redirect/302", + self.defaultHTTPBinURLPrefix + "ok", + ] + ) response = try localClient.get(url: self.defaultHTTPBinURLPrefix + "redirect/https?port=\(httpsBin.port)") .wait() @@ -501,6 +508,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { var response = try localClient.execute(request: request).wait() XCTAssertEqual(response.status, .found) XCTAssertEqual(response.headers.first(name: "Location"), targetURL) + XCTAssertEqual(response.url, request.url) + XCTAssertEqual(response.history.map(\.request.url), [request.url]) request = try Request( url: "https://localhost/(httpsBin.port)/redirect/target", @@ -512,6 +521,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { response = try localClient.execute(request: request).wait() XCTAssertEqual(response.status, .found) XCTAssertEqual(response.headers.first(name: "Location"), targetURL) + XCTAssertEqual(response.url, request.url) + XCTAssertEqual(response.history.map(\.request.url), [request.url]) // From HTTP or HTTPS to HTTPS+UNIX should also fail to redirect targetURL = @@ -526,6 +537,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { response = try localClient.execute(request: request).wait() XCTAssertEqual(response.status, .found) XCTAssertEqual(response.headers.first(name: "Location"), targetURL) + XCTAssertEqual(response.url, request.url) + XCTAssertEqual(response.history.map(\.request.url), [request.url]) request = try Request( url: "https://localhost/(httpsBin.port)/redirect/target", @@ -537,6 +550,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { response = try localClient.execute(request: request).wait() XCTAssertEqual(response.status, .found) XCTAssertEqual(response.headers.first(name: "Location"), targetURL) + XCTAssertEqual(response.url, request.url) + XCTAssertEqual(response.history.map(\.request.url), [request.url]) // ... while HTTP+UNIX to HTTP, HTTPS, or HTTP(S)+UNIX should succeed targetURL = self.defaultHTTPBinURLPrefix + "ok" @@ -550,6 +565,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { response = try localClient.execute(request: request).wait() XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) targetURL = "https://localhost/(httpsBin.port)/ok" request = try Request( @@ -562,6 +582,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { response = try localClient.execute(request: request).wait() XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) targetURL = "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" @@ -575,6 +600,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { response = try localClient.execute(request: request).wait() XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) targetURL = "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" @@ -588,6 +618,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { response = try localClient.execute(request: request).wait() XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) // ... and HTTPS+UNIX to HTTP, HTTPS, or HTTP(S)+UNIX should succeed targetURL = self.defaultHTTPBinURLPrefix + "ok" @@ -601,6 +636,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { response = try localClient.execute(request: request).wait() XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) targetURL = "https://localhost/(httpsBin.port)/ok" request = try Request( @@ -613,6 +653,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { response = try localClient.execute(request: request).wait() XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) targetURL = "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" @@ -626,6 +671,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { response = try localClient.execute(request: request).wait() XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) targetURL = "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" @@ -639,6 +689,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { response = try localClient.execute(request: request).wait() XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) } ) } @@ -729,11 +784,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { var request = try Request(url: self.defaultHTTPBinURLPrefix + "events/10/content-length") request.headers.add(name: "Accept", value: "text/event-stream") - let progress = - try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Progress in + let response = + try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Response in let delegate = try FileDownloadDelegate(path: path) - let progress = try self.defaultClient.execute( + let response = try self.defaultClient.execute( request: request, delegate: delegate ) @@ -741,19 +796,22 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { try XCTAssertEqual(50, TemporaryFileHelpers.fileSize(path: path)) - return progress + return response } - XCTAssertEqual(50, progress.totalBytes) - XCTAssertEqual(50, progress.receivedBytes) + XCTAssertEqual(.ok, response.head.status) + XCTAssertEqual("50", response.head.headers.first(name: "content-length")) + + XCTAssertEqual(50, response.totalBytes) + XCTAssertEqual(50, response.receivedBytes) } func testFileDownloadError() throws { var request = try Request(url: self.defaultHTTPBinURLPrefix + "not-found") request.headers.add(name: "Accept", value: "text/event-stream") - let progress = - try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Progress in + let response = + try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Response in let delegate = try FileDownloadDelegate( path: path, reportHead: { @@ -761,7 +819,7 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } ) - let progress = try self.defaultClient.execute( + let response = try self.defaultClient.execute( request: request, delegate: delegate ) @@ -769,11 +827,14 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertFalse(TemporaryFileHelpers.fileExists(path: path)) - return progress + return response } - XCTAssertEqual(nil, progress.totalBytes) - XCTAssertEqual(0, progress.receivedBytes) + XCTAssertEqual(.notFound, response.head.status) + XCTAssertFalse(response.head.headers.contains(name: "content-length")) + + XCTAssertEqual(nil, response.totalBytes) + XCTAssertEqual(0, response.receivedBytes) } func testFileDownloadCustomError() throws { @@ -845,8 +906,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { body: .stream { streamWriter in _ = streamWriter.write(.byteBuffer(.init())) - let promise = self.clientGroup.next().makePromise(of: Void.self) - self.clientGroup.next().scheduleTask(in: .milliseconds(3)) { + let promise = localClient.eventLoopGroup.next().makePromise(of: Void.self) + localClient.eventLoopGroup.next().scheduleTask(in: .milliseconds(3)) { streamWriter.write(.byteBuffer(.init())).cascade(to: promise) } @@ -1062,23 +1123,23 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertNoThrow(try localClient.syncShutdown()) } - class EventLoopValidatingDelegate: HTTPClientResponseDelegate { + final class EventLoopValidatingDelegate: HTTPClientResponseDelegate { typealias Response = Bool let eventLoop: EventLoop - var result = false + let result = NIOLockedValueBox(false) init(eventLoop: EventLoop) { self.eventLoop = eventLoop } func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - self.result = task.eventLoop === self.eventLoop + self.result.withLockedValue { $0 = task.eventLoop === self.eventLoop } return task.eventLoop.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws -> Bool { - self.result + self.result.withLockedValue { $0 } } } @@ -1286,7 +1347,7 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let numberOfRequestsPerThread = 1000 let numberOfParallelWorkers = 5 - final class HTTPServer: ChannelInboundHandler { + final class HTTPServer: ChannelInboundHandler, Sendable { typealias InboundIn = HTTPServerRequestPart typealias OutboundOut = HTTPServerResponsePart @@ -1332,10 +1393,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let url = "http://127.0.0.1/(server?.localAddress?.port%20??%20-1)/hello" let g = DispatchGroup() + let defaultClient = self.defaultClient! for workerID in 0.. Channel? { try? ServerBootstrap(group: group) .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline().flatMap { - channel.pipeline.addHandler( + channel.pipeline.configureHTTPServerPipeline().flatMapThrowing { + try channel.pipeline.syncOperations.addHandler( HTTPServer( headPromise: headPromise, bodyPromises: bodyPromises, @@ -2508,11 +2576,12 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testUploadStreamingCallinToleratedFromOtsideEL() throws { + let defaultClient = self.defaultClient! let request = try HTTPClient.Request( url: self.defaultHTTPBinURLPrefix + "get", method: .POST, body: .stream(contentLength: 4) { writer in - let promise = self.defaultClient.eventLoopGroup.next().makePromise(of: Void.self) + let promise = defaultClient.eventLoopGroup.next().makePromise(of: Void.self) // We have to toleare callins from any thread DispatchQueue(label: "upload-streaming").async { writer.write(.byteBuffer(ByteBuffer(string: "1234"))).whenComplete { _ in @@ -3216,12 +3285,12 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testConnectErrorPropagatedToDelegate() throws { - class TestDelegate: HTTPClientResponseDelegate { + final class TestDelegate: HTTPClientResponseDelegate { typealias Response = Void - var error: Error? + let error = NIOLockedValueBox(nil) func didFinishRequest(task: HTTPClient.Task) throws {} func didReceiveError(task: HTTPClient.Task, _ error: Error) { - self.error = error + self.error.withLockedValue { $0 = error } } } @@ -3240,12 +3309,12 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertThrowsError(try httpClient.execute(request: request, delegate: delegate).wait()) { XCTAssertEqualTypeAndValue($0, HTTPClientError.connectTimeout) - XCTAssertEqualTypeAndValue(delegate.error, HTTPClientError.connectTimeout) + XCTAssertEqualTypeAndValue(delegate.error.withLockedValue { $0 }, HTTPClientError.connectTimeout) } } func testDelegateCallinsTolerateRandomEL() throws { - class TestDelegate: HTTPClientResponseDelegate { + final class TestDelegate: HTTPClientResponseDelegate { typealias Response = Void let eventLoop: EventLoop @@ -3291,15 +3360,50 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertNoThrow(try future.wait()) } + func testDelegateGetsErrorsFromCreatingRequestBag() throws { + // We want to test that we propagate errors to the delegate from failures to construct the + // request bag. Those errors only come from invalid headers. + final class TestDelegate: HTTPClientResponseDelegate, Sendable { + typealias Response = Void + let error: NIOLockedValueBox = .init(nil) + func didFinishRequest(task: HTTPClient.Task) throws {} + func didReceiveError(task: HTTPClient.Task, _ error: Error) { + self.error.withLockedValue { $0 = error } + } + } + + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup) + ) + + defer { + XCTAssertNoThrow(try httpClient.syncShutdown()) + } + + // 198.51.100.254 is reserved for documentation only + var request = try HTTPClient.Request(url: "http://198.51.100.254:65535/get") + request.headers.replaceOrAdd(name: "Not-ASCII", value: "not-fine\n") + let delegate = TestDelegate() + + XCTAssertThrowsError(try httpClient.execute(request: request, delegate: delegate).wait()) { + XCTAssertEqualTypeAndValue($0, HTTPClientError.invalidHeaderFieldValues(["not-fine\n"])) + XCTAssertEqualTypeAndValue( + delegate.error.withLockedValue { $0 }, + HTTPClientError.invalidHeaderFieldValues(["not-fine\n"]) + ) + } + } + func testContentLengthTooLongFails() throws { let url = self.defaultHTTPBinURLPrefix + "post" + let defaultClient = self.defaultClient! XCTAssertThrowsError( try self.defaultClient.execute( request: Request( url: url, body: .stream(contentLength: 10) { streamWriter in - let promise = self.defaultClient.eventLoopGroup.next().makePromise(of: Void.self) + let promise = defaultClient.eventLoopGroup.next().makePromise(of: Void.self) DispatchQueue(label: "content-length-test").async { streamWriter.write(.byteBuffer(ByteBuffer(string: "1"))).cascade(to: promise) } @@ -3395,6 +3499,7 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { // second connection. _ = self.defaultClient.get(url: "http://localhost/(self.defaultHTTPBin.port)/events/10/1") + let clientGroup = self.clientGroup! var request = try HTTPClient.Request(url: "http://localhost/(self.defaultHTTPBin.port)/wait", method: .POST) request.body = .stream { writer in // Start writing chunks so tha we will try to write after read timeout is thrown @@ -3402,8 +3507,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { _ = writer.write(.byteBuffer(ByteBuffer(string: "1234"))) } - let promise = self.clientGroup.next().makePromise(of: Void.self) - self.clientGroup.next().scheduleTask(in: .milliseconds(3)) { + let promise = clientGroup.next().makePromise(of: Void.self) + clientGroup.next().scheduleTask(in: .milliseconds(3)) { writer.write(.byteBuffer(ByteBuffer(string: "1234"))).cascade(to: promise) } @@ -3418,7 +3523,7 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testSSLHandshakeErrorPropagation() throws { - class CloseHandler: ChannelInboundHandler { + final class CloseHandler: ChannelInboundHandler, Sendable { typealias InboundIn = Any func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -3475,11 +3580,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { func testSSLHandshakeErrorPropagationDelayedClose() throws { // This is as the test above, but the close handler delays its close action by a few hundred ms. // This will tend to catch the pipeline at different weird stages, and flush out different bugs. - class CloseHandler: ChannelInboundHandler { + final class CloseHandler: ChannelInboundHandler, Sendable { typealias InboundIn = Any func channelRead(context: ChannelHandlerContext, data: NIOAny) { - context.eventLoop.scheduleTask(in: .milliseconds(100)) { + context.eventLoop.assumeIsolated().scheduleTask(in: .milliseconds(100)) { context.close(promise: nil) } } @@ -3536,8 +3641,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let server = try ServerBootstrap(group: self.serverGroup) .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline().flatMap { - channel.pipeline.addHandler(CloseWithoutClosingServerHandler(group.leave)) + channel.pipeline.configureHTTPServerPipeline().flatMapThrowing { + try channel.pipeline.syncOperations.addHandler(CloseWithoutClosingServerHandler(group.leave)) } } .bind(host: "localhost", port: 0) @@ -3906,11 +4011,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { var request = try Request(url: self.defaultHTTPBinURLPrefix + "chunked") request.headers.add(name: "Accept", value: "text/event-stream") - let progress = - try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Progress in + let response = + try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Response in let delegate = try FileDownloadDelegate(path: path) - let progress = try self.defaultClient.execute( + let response = try self.defaultClient.execute( request: request, delegate: delegate ) @@ -3918,11 +4023,15 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { try XCTAssertEqual(50, TemporaryFileHelpers.fileSize(path: path)) - return progress + return response } - XCTAssertEqual(nil, progress.totalBytes) - XCTAssertEqual(50, progress.receivedBytes) + XCTAssertEqual(.ok, response.head.status) + XCTAssertEqual("chunked", response.head.headers.first(name: "transfer-encoding")) + XCTAssertFalse(response.head.headers.contains(name: "content-length")) + + XCTAssertEqual(nil, response.totalBytes) + XCTAssertEqual(50, response.receivedBytes) } func testCloseWhileBackpressureIsExertedIsFine() throws { @@ -4118,12 +4227,76 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertNoThrow(try client.execute(request: request).wait()) } + func testCancelingRequestAfterRedirect() throws { + let request = try Request( + url: self.defaultHTTPBinURLPrefix + "redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": self.defaultHTTPBinURLPrefix + "wait"], + body: nil + ) + + final class CancelAfterRedirect: HTTPClientResponseDelegate, Sendable { + init() {} + func didFinishRequest(task: AsyncHTTPClient.HTTPClient.Task) throws {} + } + + let task = defaultClient.execute( + request: request, + delegate: CancelAfterRedirect(), + deadline: .now() + .seconds(1) + ) + + // there is currently no HTTPClientResponseDelegate method to ensure the redirect occurs before we cancel, so we just sleep for 500ms + Thread.sleep(forTimeInterval: 0.5) + + task.cancel() + + XCTAssertThrowsError(try task.wait()) { error in + guard case let error = error as? HTTPClientError, error == .cancelled else { + return XCTFail("Should fail with cancelled") + } + } + } + + func testFailingRequestAfterRedirect() throws { + let request = try Request( + url: self.defaultHTTPBinURLPrefix + "redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": self.defaultHTTPBinURLPrefix + "wait"], + body: nil + ) + + final class FailAfterRedirect: HTTPClientResponseDelegate, Sendable { + init() {} + func didFinishRequest(task: AsyncHTTPClient.HTTPClient.Task) throws {} + } + + let task = defaultClient.execute( + request: request, + delegate: FailAfterRedirect(), + deadline: .now() + .seconds(1) + ) + + // there is currently no HTTPClientResponseDelegate method to ensure the redirect occurs before we fail, so we just sleep for 500ms + Thread.sleep(forTimeInterval: 0.5) + + struct TestError: Error {} + + task.fail(reason: TestError()) + + XCTAssertThrowsError(try task.wait()) { error in + guard error is TestError else { + return XCTFail("Should fail with TestError") + } + } + } + func testCancelingHTTP1RequestAfterHeaderSend() throws { var request = try HTTPClient.Request(url: self.defaultHTTPBin.baseURL + "/wait", method: .POST) // non-empty body is important request.body = .byteBuffer(ByteBuffer([1])) - class CancelAfterHeadSend: HTTPClientResponseDelegate { + final class CancelAfterHeadSend: HTTPClientResponseDelegate, Sendable { init() {} func didFinishRequest(task: AsyncHTTPClient.HTTPClient.Task) throws {} func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { @@ -4140,7 +4313,7 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { // non-empty body is important request.body = .byteBuffer(ByteBuffer([1])) - class CancelAfterHeadSend: HTTPClientResponseDelegate { + final class CancelAfterHeadSend: HTTPClientResponseDelegate, Sendable { init() {} func didFinishRequest(task: AsyncHTTPClient.HTTPClient.Task) throws {} func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { @@ -4302,4 +4475,174 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { request.setBasicAuth(username: "foo", password: "bar") XCTAssertEqual(request.headers.first(name: "Authorization"), "Basic Zm9vOmJhcg==") } + + func runBaseTestForHTTP1ConnectionDebugInitializer(ssl: Bool) { + let connectionDebugInitializerUtil = CountingDebugInitializerUtil() + + // Initializing even with just `http1_1ConnectionDebugInitializer` (rather than manually + // modifying `config`) to ensure that the matching `init` actually wires up this argument + // with the respective property. This is necessary as these parameters are defaulted and can + // be easy to miss. + var config = HTTPClient.Configuration( + http1_1ConnectionDebugInitializer: { channel in + connectionDebugInitializerUtil.initialize(channel: channel) + } + ) + config.httpVersion = .http1Only + + if ssl { + config.tlsConfiguration = .clientDefault + config.tlsConfiguration?.certificateVerification = .none + } + + let higherConnectTimeout = CountingDebugInitializerUtil.duration + .milliseconds(100) + var configWithHigherTimeout = config + configWithHigherTimeout.timeout = .init(connect: higherConnectTimeout) + + let clientWithHigherTimeout = HTTPClient( + eventLoopGroupProvider: .singleton, + configuration: configWithHigherTimeout, + backgroundActivityLogger: Logger( + label: "HTTPClient", + factory: StreamLogHandler.standardOutput(label:) + ) + ) + defer { XCTAssertNoThrow(try clientWithHigherTimeout.syncShutdown()) } + + let bin = HTTPBin(.http1_1(ssl: ssl, compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + let scheme = ssl ? "https" : "http" + + for _ in 0..<3 { + XCTAssertNoThrow( + try clientWithHigherTimeout.get(url: "\(scheme)://localhost:\(bin.port)/get").wait() + ) + } + + // Even though multiple requests were made, the connection debug initializer must be called + // only once. + XCTAssertEqual(connectionDebugInitializerUtil.executionCount, 1) + + let lowerConnectTimeout = CountingDebugInitializerUtil.duration - .milliseconds(100) + var configWithLowerTimeout = config + configWithLowerTimeout.timeout = .init(connect: lowerConnectTimeout) + + let clientWithLowerTimeout = HTTPClient( + eventLoopGroupProvider: .singleton, + configuration: configWithLowerTimeout, + backgroundActivityLogger: Logger( + label: "HTTPClient", + factory: StreamLogHandler.standardOutput(label:) + ) + ) + defer { XCTAssertNoThrow(try clientWithLowerTimeout.syncShutdown()) } + + XCTAssertThrowsError( + try clientWithLowerTimeout.get(url: "\(scheme)://localhost:\(bin.port)/get").wait() + ) { + XCTAssertEqual($0 as? HTTPClientError, .connectTimeout) + } + } + + func testHTTP1PlainTextConnectionDebugInitializer() { + runBaseTestForHTTP1ConnectionDebugInitializer(ssl: false) + } + + func testHTTP1EncryptedConnectionDebugInitializer() { + runBaseTestForHTTP1ConnectionDebugInitializer(ssl: true) + } + + func testHTTP2ConnectionAndStreamChannelDebugInitializers() { + let connectionDebugInitializerUtil = CountingDebugInitializerUtil() + let streamChannelDebugInitializerUtil = CountingDebugInitializerUtil() + + // Initializing even with just `http2ConnectionDebugInitializer` and + // `http2StreamChannelDebugInitializer` (rather than manually modifying `config`) to ensure + // that the matching `init` actually wires up these arguments with the respective + // properties. This is necessary as these parameters are defaulted and can be easy to miss. + var config = HTTPClient.Configuration( + http2ConnectionDebugInitializer: { channel in + connectionDebugInitializerUtil.initialize(channel: channel) + }, + http2StreamChannelDebugInitializer: { channel in + streamChannelDebugInitializerUtil.initialize(channel: channel) + } + ) + config.tlsConfiguration = .clientDefault + config.tlsConfiguration?.certificateVerification = .none + config.httpVersion = .automatic + + let higherConnectTimeout = CountingDebugInitializerUtil.duration + .milliseconds(100) + var configWithHigherTimeout = config + configWithHigherTimeout.timeout = .init(connect: higherConnectTimeout) + + let clientWithHigherTimeout = HTTPClient( + eventLoopGroupProvider: .singleton, + configuration: configWithHigherTimeout, + backgroundActivityLogger: Logger( + label: "HTTPClient", + factory: StreamLogHandler.standardOutput(label:) + ) + ) + defer { XCTAssertNoThrow(try clientWithHigherTimeout.syncShutdown()) } + + let bin = HTTPBin(.http2(compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + let numberOfRequests = 3 + + for _ in 0..(0) + var executionCount: Int { self._executionCount.withLockedValue { $0 } } + + /// The minimum time to spend running the debug initializer. + static let duration: TimeAmount = .milliseconds(300) + + /// The actual debug initializer. + func initialize(channel: Channel) -> EventLoopFuture { + self._executionCount.withLockedValue { $0 += 1 } + + let someScheduledTask = channel.eventLoop.scheduleTask(in: Self.duration) { + channel.eventLoop.makeSucceededVoidFuture() + } + + return someScheduledTask.futureResult.flatMap { $0 } + } } diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift index d9dbd4cb1..15cc9e7e9 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift @@ -184,7 +184,7 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { } } -class NeverrespondServerHandler: ChannelInboundHandler { +final class NeverrespondServerHandler: ChannelInboundHandler, Sendable { typealias InboundIn = NIOAny func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -194,11 +194,11 @@ class NeverrespondServerHandler: ChannelInboundHandler { /// A `HTTPConnectionRequester` that will fail a test if any of its methods are ever called. final class ExplodingRequester: HTTPConnectionRequester { - func http1ConnectionCreated(_: HTTP1Connection) { + func http1ConnectionCreated(_: HTTP1Connection.SendableView) { XCTFail("http1ConnectionCreated called unexpectedly") } - func http2ConnectionCreated(_: HTTP2Connection, maximumStreams: Int) { + func http2ConnectionCreated(_: HTTP2Connection.SendableView, maximumStreams: Int) { XCTFail("http2ConnectionCreated called unexpectedly") } diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift index d792895d3..4f4bbd785 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift @@ -83,7 +83,7 @@ class HTTPConnectionPool_RequestQueueTests: XCTestCase { } } -private class MockScheduledRequest: HTTPSchedulableRequest { +final private class MockScheduledRequest: HTTPSchedulableRequest { let requiredEventLoop: EventLoop? init(requiredEventLoop: EventLoop?) { diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift b/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift index 021c69731..67f18cbb8 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// import Logging +import NIOConcurrencyHelpers import NIOCore import NIOHTTP1 import XCTest @@ -20,7 +21,7 @@ import XCTest @testable import AsyncHTTPClient final class MockHTTPExecutableRequest: HTTPExecutableRequest { - enum Event { + enum Event: Sendable { /// ``Event`` without associated values enum Kind: Hashable { case willExecuteRequest @@ -56,39 +57,49 @@ final class MockHTTPExecutableRequest: HTTPExecutableRequest { } } - var logger: Logging.Logger = Logger(label: "request") - var requestHead: NIOHTTP1.HTTPRequestHead - var requestFramingMetadata: RequestFramingMetadata - var requestOptions: RequestOptions = .forTests() + let logger: Logging.Logger = Logger(label: "request") + let requestHead: NIOHTTP1.HTTPRequestHead + let requestFramingMetadata: RequestFramingMetadata + let requestOptions: RequestOptions = .forTests() /// if true and ``HTTPExecutableRequest`` method is called without setting a corresponding callback on `self` e.g. /// If ``HTTPExecutableRequest\.willExecuteRequest(_:)`` is called but ``willExecuteRequestCallback`` is not set, /// ``XCTestFail(_:)`` will be called to fail the current test. - var raiseErrorIfUnimplementedMethodIsCalled: Bool = true - private var file: StaticString - private var line: UInt - - var willExecuteRequestCallback: ((HTTPRequestExecutor) -> Void)? - var requestHeadSentCallback: (() -> Void)? - var resumeRequestBodyStreamCallback: (() -> Void)? - var pauseRequestBodyStreamCallback: (() -> Void)? - var receiveResponseHeadCallback: ((HTTPResponseHead) -> Void)? - var receiveResponseBodyPartsCallback: ((CircularBuffer) -> Void)? - var succeedRequestCallback: ((CircularBuffer?) -> Void)? - var failCallback: ((Error) -> Void)? + let raiseErrorIfUnimplementedMethodIsCalled: Bool + private let file: StaticString + private let line: UInt + + let willExecuteRequestCallback: (@Sendable (HTTPRequestExecutor) -> Void)? = nil + let requestHeadSentCallback: (@Sendable () -> Void)? = nil + let resumeRequestBodyStreamCallback: (@Sendable () -> Void)? = nil + let pauseRequestBodyStreamCallback: (@Sendable () -> Void)? = nil + let receiveResponseHeadCallback: (@Sendable (HTTPResponseHead) -> Void)? = nil + let receiveResponseBodyPartsCallback: (@Sendable (CircularBuffer) -> Void)? = nil + let succeedRequestCallback: (@Sendable (CircularBuffer?) -> Void)? = nil + let failCallback: (@Sendable (Error) -> Void)? = nil /// captures all ``HTTPExecutableRequest`` method calls in the order of occurrence, including arguments. /// If you are not interested in the arguments you can use `events.map(\.kind)` to get all events without arguments. - private(set) var events: [Event] = [] + private let _events = NIOLockedValueBox<[Event]>([]) + private(set) var events: [Event] { + get { + self._events.withLockedValue { $0 } + } + set { + self._events.withLockedValue { $0 = newValue } + } + } init( head: NIOHTTP1.HTTPRequestHead = .init(version: .http1_1, method: .GET, uri: "http://localhost/"), framingMetadata: RequestFramingMetadata = .init(connectionClose: false, body: .fixedSize(0)), + raiseErrorIfUnimplementedMethodIsCalled: Bool = true, file: StaticString = #file, line: UInt = #line ) { self.requestHead = head self.requestFramingMetadata = framingMetadata + self.raiseErrorIfUnimplementedMethodIsCalled = raiseErrorIfUnimplementedMethodIsCalled self.file = file self.line = line } diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift b/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift index f85c75ce5..e5d9caa8e 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift @@ -25,7 +25,7 @@ final class MockRequestExecutor { case unexpectedByteBuffer } - enum RequestParts: Equatable { + enum RequestParts: Equatable, Sendable { case body(IOData) case endOfStream @@ -58,10 +58,15 @@ final class MockRequestExecutor { private let responseBodyDemandLock = ConditionLock(value: false) private let cancellationLock = ConditionLock(value: false) - private var request: HTTPExecutableRequest? - private var _signaledDemandForRequestBody: Bool = false + private struct State: Sendable { + var request: HTTPExecutableRequest? + var _signaledDemandForRequestBody: Bool = false + } + + private let state: NIOLockedValueBox init(pauseRequestBodyPartStreamAfterASingleWrite: Bool = false, eventLoop: EventLoop) { + self.state = NIOLockedValueBox(State()) self.pauseRequestBodyPartStreamAfterASingleWrite = pauseRequestBodyPartStreamAfterASingleWrite self.eventLoop = eventLoop } @@ -77,8 +82,10 @@ final class MockRequestExecutor { } private func runRequest0(_ request: HTTPExecutableRequest) { - precondition(self.request == nil) - self.request = request + self.state.withLockedValue { + precondition($0.request == nil) + $0.request = request + } request.willExecuteRequest(self) request.requestHeadSent() } @@ -127,10 +134,16 @@ final class MockRequestExecutor { } private func pauseRequestBodyStream0() { - if self._signaledDemandForRequestBody == true { - self._signaledDemandForRequestBody = false - self.request!.pauseRequestBodyStream() + let request = self.state.withLockedValue { + if $0._signaledDemandForRequestBody == true { + $0._signaledDemandForRequestBody = false + return $0.request + } else { + return nil + } } + + request?.pauseRequestBodyStream() } func resumeRequestBodyStream() { @@ -144,10 +157,16 @@ final class MockRequestExecutor { } private func resumeRequestBodyStream0() { - if self._signaledDemandForRequestBody == false { - self._signaledDemandForRequestBody = true - self.request!.resumeRequestBodyStream() + let request = self.state.withLockedValue { + if $0._signaledDemandForRequestBody == false { + $0._signaledDemandForRequestBody = true + return $0.request + } else { + return nil + } } + + request?.resumeRequestBodyStream() } func resetResponseStreamDemandSignal() { @@ -204,11 +223,13 @@ extension MockRequestExecutor: HTTPRequestExecutor { case none } - let stateChange = { () -> WriteAction in + let stateChange = { @Sendable () -> WriteAction in var pause = false if self.blockingQueue.isEmpty && self.pauseRequestBodyPartStreamAfterASingleWrite && part.isBody { pause = true - self._signaledDemandForRequestBody = false + self.state.withLockedValue { + $0._signaledDemandForRequestBody = false + } } self.blockingQueue.append(.success(part)) @@ -283,3 +304,5 @@ extension MockRequestExecutor { } } } + +extension MockRequestExecutor.BlockingQueue: @unchecked Sendable where Element: Sendable {} diff --git a/Tests/AsyncHTTPClientTests/NWWaitingHandlerTests.swift b/Tests/AsyncHTTPClientTests/NWWaitingHandlerTests.swift index 033214ffe..63eaf649d 100644 --- a/Tests/AsyncHTTPClientTests/NWWaitingHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/NWWaitingHandlerTests.swift @@ -16,6 +16,7 @@ @testable import AsyncHTTPClient import Network import NIOCore +import NIOConcurrencyHelpers import NIOEmbedded import NIOSSL import NIOTransportServices @@ -23,21 +24,41 @@ import XCTest @available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 5.0, *) class NWWaitingHandlerTests: XCTestCase { - class MockRequester: HTTPConnectionRequester { - var waitingForConnectivityCalled = false - var connectionID: AsyncHTTPClient.HTTPConnectionPool.Connection.ID? - var transientError: NWError? + final class MockRequester: HTTPConnectionRequester { + private struct State: Sendable { + var waitingForConnectivityCalled = false + var connectionID: AsyncHTTPClient.HTTPConnectionPool.Connection.ID? + var transientError: NWError? + } + + private let state = NIOLockedValueBox(State()) + + var waitingForConnectivityCalled: Bool { + self.state.withLockedValue { $0.waitingForConnectivityCalled } + } + + var connectionID: AsyncHTTPClient.HTTPConnectionPool.Connection.ID? { + self.state.withLockedValue { $0.connectionID } + } + + var transientError: NWError? { + self.state.withLockedValue { + $0.transientError + } + } - func http1ConnectionCreated(_: AsyncHTTPClient.HTTP1Connection) {} + func http1ConnectionCreated(_: AsyncHTTPClient.HTTP1Connection.SendableView) {} - func http2ConnectionCreated(_: AsyncHTTPClient.HTTP2Connection, maximumStreams: Int) {} + func http2ConnectionCreated(_: AsyncHTTPClient.HTTP2Connection.SendableView, maximumStreams: Int) {} func failedToCreateHTTPConnection(_: AsyncHTTPClient.HTTPConnectionPool.Connection.ID, error: Error) {} func waitingForConnectivity(_ connectionID: AsyncHTTPClient.HTTPConnectionPool.Connection.ID, error: Error) { - self.waitingForConnectivityCalled = true - self.connectionID = connectionID - self.transientError = error as? NWError + self.state.withLockedValue { + $0.waitingForConnectivityCalled = true + $0.connectionID = connectionID + $0.transientError = error as? NWError + } } } diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests.swift b/Tests/AsyncHTTPClientTests/RequestBagTests.swift index 9aa595224..2b0c2f6e4 100644 --- a/Tests/AsyncHTTPClientTests/RequestBagTests.swift +++ b/Tests/AsyncHTTPClientTests/RequestBagTests.swift @@ -127,6 +127,7 @@ final class RequestBagTests: XCTestCase { XCTAssertNoThrow(try executor.receiveEndOfStream()) XCTAssertEqual(receivedBytes, bytesToSent, "We have sent all request bytes...") + XCTAssertTrue(delegate.history.isEmpty) XCTAssertNil(delegate.receivedHead, "Expected not to have a response head, before `receiveResponseHead`") let responseHead = HTTPResponseHead( version: .http1_1, @@ -140,6 +141,10 @@ final class RequestBagTests: XCTestCase { XCTAssertEqual(responseHead, delegate.receivedHead) XCTAssertNoThrow(try XCTUnwrap(delegate.backpressurePromise).succeed(())) XCTAssertTrue(executor.signalledDemandForResponseBody) + + XCTAssertEqual(delegate.history.map(\.request.url), [request.url]) + XCTAssertEqual(delegate.history.map(\.response), [responseHead]) + executor.resetResponseStreamDemandSignal() // we will receive 20 chunks with each 10 byteBuffers and 32 bytes @@ -747,13 +752,15 @@ final class RequestBagTests: XCTestCase { let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) executor.runRequest(bag) XCTAssertFalse(executor.signalledDemandForResponseBody) - bag.receiveResponseHead( - .init( - version: .http1_1, - status: .permanentRedirect, - headers: ["content-length": "\(3 * 1024)", "location": "https://swift.org/sswg"] - ) + XCTAssertTrue(delegate.history.isEmpty) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .permanentRedirect, + headers: ["content-length": "\(3 * 1024)", "location": "https://swift.org/sswg"] ) + bag.receiveResponseHead(responseHead) + XCTAssertEqual(delegate.history.map(\.request.url), [request.url]) + XCTAssertEqual(delegate.history.map(\.response), [responseHead]) XCTAssertNil(delegate.backpressurePromise) XCTAssertTrue(executor.signalledDemandForResponseBody) executor.resetResponseStreamDemandSignal() @@ -833,13 +840,15 @@ final class RequestBagTests: XCTestCase { let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) executor.runRequest(bag) XCTAssertFalse(executor.signalledDemandForResponseBody) - bag.receiveResponseHead( - .init( - version: .http1_1, - status: .permanentRedirect, - headers: ["content-length": "\(4 * 1024)", "location": "https://swift.org/sswg"] - ) + XCTAssertTrue(delegate.history.isEmpty) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .permanentRedirect, + headers: ["content-length": "\(4 * 1024)", "location": "https://swift.org/sswg"] ) + bag.receiveResponseHead(responseHead) + XCTAssertEqual(delegate.history.map(\.request.url), [request.url]) + XCTAssertEqual(delegate.history.map(\.response), [responseHead]) XCTAssertNil(delegate.backpressurePromise) XCTAssertFalse(executor.signalledDemandForResponseBody) XCTAssertTrue(executor.isCancelled) @@ -893,13 +902,15 @@ final class RequestBagTests: XCTestCase { let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) executor.runRequest(bag) XCTAssertFalse(executor.signalledDemandForResponseBody) - bag.receiveResponseHead( - .init( - version: .http1_1, - status: .permanentRedirect, - headers: ["content-length": "\(3 * 1024)", "location": "https://swift.org/sswg"] - ) + XCTAssertTrue(delegate.history.isEmpty) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .permanentRedirect, + headers: ["content-length": "\(3 * 1024)", "location": "https://swift.org/sswg"] ) + bag.receiveResponseHead(responseHead) + XCTAssertEqual(delegate.history.map(\.request.url), [request.url]) + XCTAssertEqual(delegate.history.map(\.response), [responseHead]) XCTAssertNil(delegate.backpressurePromise) XCTAssertTrue(executor.signalledDemandForResponseBody) executor.resetResponseStreamDemandSignal() @@ -928,7 +939,7 @@ final class RequestBagTests: XCTestCase { } func testWeDontLeakTheRequestIfTheRequestWriterWasCapturedByAPromise() { - final class LeakDetector {} + final class LeakDetector: Sendable {} let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } @@ -989,65 +1000,90 @@ extension HTTPClient.Task { } } -class UploadCountingDelegate: HTTPClientResponseDelegate { +final class UploadCountingDelegate: HTTPClientResponseDelegate { typealias Response = Void let eventLoop: EventLoop - private(set) var hitDidSendRequestHead = 0 - private(set) var hitDidSendRequestPart = 0 - private(set) var hitDidSendRequest = 0 - private(set) var hitDidReceiveResponse = 0 - private(set) var hitDidReceiveBodyPart = 0 - private(set) var hitDidReceiveError = 0 + struct State: Sendable { + var hitDidSendRequestHead = 0 + var hitDidSendRequestPart = 0 + var hitDidSendRequest = 0 + var hitDidReceiveResponse = 0 + var hitDidReceiveBodyPart = 0 + var hitDidReceiveError = 0 + + var history: [(request: HTTPClient.Request, response: HTTPResponseHead)] = [] + var receivedHead: HTTPResponseHead? + var lastBodyPart: ByteBuffer? + var backpressurePromise: EventLoopPromise? + var lastError: Error? + } + + private let state: NIOLoopBoundBox + + var hitDidSendRequestHead: Int { self.state.value.hitDidSendRequestHead } + var hitDidSendRequestPart: Int { self.state.value.hitDidSendRequestPart } + var hitDidSendRequest: Int { self.state.value.hitDidSendRequest } + var hitDidReceiveResponse: Int { self.state.value.hitDidReceiveResponse } + var hitDidReceiveBodyPart: Int { self.state.value.hitDidReceiveBodyPart } + var hitDidReceiveError: Int { self.state.value.hitDidReceiveError } - private(set) var receivedHead: HTTPResponseHead? - private(set) var lastBodyPart: ByteBuffer? - private(set) var backpressurePromise: EventLoopPromise? - private(set) var lastError: Error? + var history: [(request: HTTPClient.Request, response: HTTPResponseHead)] { + self.state.value.history + } + var receivedHead: HTTPResponseHead? { self.state.value.receivedHead } + var lastBodyPart: ByteBuffer? { self.state.value.lastBodyPart } + var backpressurePromise: EventLoopPromise? { self.state.value.backpressurePromise } + var lastError: Error? { self.state.value.lastError } init(eventLoop: EventLoop) { self.eventLoop = eventLoop + self.state = .makeBoxSendingValue(State(), eventLoop: eventLoop) } func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { - self.hitDidSendRequestHead += 1 + self.state.value.hitDidSendRequestHead += 1 } func didSendRequestPart(task: HTTPClient.Task, _ part: IOData) { - self.hitDidSendRequestPart += 1 + self.state.value.hitDidSendRequestPart += 1 } func didSendRequest(task: HTTPClient.Task) { - self.hitDidSendRequest += 1 + self.state.value.hitDidSendRequest += 1 + } + + func didVisitURL(task: HTTPClient.Task, _ request: HTTPClient.Request, _ head: HTTPResponseHead) { + self.state.value.history.append((request, head)) } func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - self.receivedHead = head + self.state.value.receivedHead = head return self.createBackpressurePromise() } func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { - assert(self.backpressurePromise == nil) - self.hitDidReceiveBodyPart += 1 - self.lastBodyPart = buffer + assert(self.state.value.backpressurePromise == nil) + self.state.value.hitDidReceiveBodyPart += 1 + self.state.value.lastBodyPart = buffer return self.createBackpressurePromise() } func didFinishRequest(task: HTTPClient.Task) throws { - self.hitDidReceiveResponse += 1 + self.state.value.hitDidReceiveResponse += 1 } func didReceiveError(task: HTTPClient.Task, _ error: Error) { - self.hitDidReceiveError += 1 - self.lastError = error + self.state.value.hitDidReceiveError += 1 + self.state.value.lastError = error } private func createBackpressurePromise() -> EventLoopFuture { - assert(self.backpressurePromise == nil) - self.backpressurePromise = self.eventLoop.makePromise(of: Void.self) - return self.backpressurePromise!.futureResult.always { _ in - self.backpressurePromise = nil + assert(self.state.value.backpressurePromise == nil) + self.state.value.backpressurePromise = self.eventLoop.makePromise(of: Void.self) + return self.state.value.backpressurePromise!.futureResult.always { _ in + self.state.value.backpressurePromise = nil } } } diff --git a/Tests/AsyncHTTPClientTests/Resources/example.com.cert.pem b/Tests/AsyncHTTPClientTests/Resources/example.com.cert.pem index f6314d47a..f16590cde 100644 --- a/Tests/AsyncHTTPClientTests/Resources/example.com.cert.pem +++ b/Tests/AsyncHTTPClientTests/Resources/example.com.cert.pem @@ -1,12 +1,12 @@ -----BEGIN CERTIFICATE----- -MIIBxDCCAUmgAwIBAgIVAPY31L1kyEnjO1E4inpE7+SYRO9mMAoGCCqGSM49BAMD -MCoxFDASBgNVBAoMC1NlbGYgU2lnbmVkMRIwEAYDVQQDDAlsb2NhbGhvc3QwHhcN -MjQwMzI4MjI0MDUyWhcNMjUwMzI4MjI0MDUyWjAqMRQwEgYDVQQKDAtTZWxmIFNp -Z25lZDESMBAGA1UEAwwJbG9jYWxob3N0MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAE -o2i+uiLtMu0Jzsk3oEUnfoM9n44/aV9UeOXxyDs57i2E13HrJeWIXACetybkB+Q8 -Poab6ohbskTwrS7WN3tFgoGdRBCKQow/rTECdezR/fdz2cGADaBN+CNMuFSnFSr5 -oy8wLTAWBgNVHREEDzANggtleGFtcGxlLmNvbTATBgNVHSUEDDAKBggrBgEFBQcD -ATAKBggqhkjOPQQDAwNpADBmAjEAwF5OlUBOloDTIAxgaSSvHBMSVOE1rY5hUlkT -kQ+dQFeUe3Fn+Er5ohvkt+qVOQ5yAjEAt9s5b/Iz+JmWxKKUyExHob6QHEuuHmJy -AKdrn20Ply60bb8qxGYHhwhoyV2MZYVV +MIIBwTCCAUigAwIBAgIUX7f9BABxGdAqG5EvLpQScFt9lOkwCgYIKoZIzj0EAwMw +KjEUMBIGA1UECgwLU2VsZiBTaWduZWQxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0y +NTA0MDExNDMwMTFaFw0yNjA0MDExNDMwMTFaMCoxFDASBgNVBAoMC1NlbGYgU2ln +bmVkMRIwEAYDVQQDDAlsb2NhbGhvc3QwdjAQBgcqhkjOPQIBBgUrgQQAIgNiAAQW +szfO5HCWIWgKUqyXUU0pFpYgaq01RRL69XZz1CkV6XTrxMfIvvwez2886EQDL8QX +i5NpKg3qvPgWuDjVHaj4WEJe5XMNqcujxcTufBlmaQ6o4vtoK7CIHDIDldF/HRij +LzAtMBYGA1UdEQQPMA2CC2V4YW1wbGUuY29tMBMGA1UdJQQMMAoGCCsGAQUFBwMB +MAoGCCqGSM49BAMDA2cAMGQCMBJ8Dxg0qX2bEZ3r6dI3UCGAUYxJDVk+XhiIY1Fm +5FJeQqhaVayCRPrPXXGZUJGY/wIwXej70FwkxHKLq+XxfHTC5CzmoOK469C9Rk9Y +ucddXM83ebFxVNgRCWetH9tDdXJ9 -----END CERTIFICATE----- \ No newline at end of file diff --git a/Tests/AsyncHTTPClientTests/Resources/example.com.private-key.pem b/Tests/AsyncHTTPClientTests/Resources/example.com.private-key.pem index 7cf27cc35..3ad9ce79e 100644 --- a/Tests/AsyncHTTPClientTests/Resources/example.com.private-key.pem +++ b/Tests/AsyncHTTPClientTests/Resources/example.com.private-key.pem @@ -1,6 +1,6 @@ -----BEGIN PRIVATE KEY----- -MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDDhC5OSjPQeYRm4irIH -z4EyM/NbJsX39SlI6J4/q0Syt0BwojgJKhCWfeveanbIjbWhZANiAASjaL66Iu0y -7QnOyTegRSd+gz2fjj9pX1R45fHIOznuLYTXcesl5YhcAJ63JuQH5Dw+hpvqiFuy -RPCtLtY3e0WCgZ1EEIpCjD+tMQJ17NH993PZwYANoE34I0y4VKcVKvk= +MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDD9v51MTOcgFIbiHbok +U+QOubosGF1u1q+D3fEUb1U2cgjCofKmPHekXTz0xu9MJi2hZANiAAQWszfO5HCW +IWgKUqyXUU0pFpYgaq01RRL69XZz1CkV6XTrxMfIvvwez2886EQDL8QXi5NpKg3q +vPgWuDjVHaj4WEJe5XMNqcujxcTufBlmaQ6o4vtoK7CIHDIDldF/HRg= -----END PRIVATE KEY----- \ No newline at end of file diff --git a/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests.swift b/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests.swift index 1170aa444..2352c6c1c 100644 --- a/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests.swift @@ -38,7 +38,7 @@ class SOCKSEventsHandlerTests: XCTestCase { let embedded = EmbeddedChannel(handlers: [socksEventsHandler]) XCTAssertNotNil(socksEventsHandler.socksEstablishedFuture) - XCTAssertNoThrow(try embedded.pipeline.removeHandler(socksEventsHandler).wait()) + XCTAssertNoThrow(try embedded.pipeline.syncOperations.removeHandler(socksEventsHandler).wait()) XCTAssertThrowsError(try XCTUnwrap(socksEventsHandler.socksEstablishedFuture).wait()) } diff --git a/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift b/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift index 6dda7d928..50d26b278 100644 --- a/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift @@ -53,23 +53,27 @@ class MockSOCKSServer { bootstrap = ServerBootstrap(group: elg) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .childChannelInitializer { channel in - channel.pipeline.addHandler(TestSOCKSBadServerHandler()) + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(TestSOCKSBadServerHandler()) + } } } else { bootstrap = ServerBootstrap(group: elg) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .childChannelInitializer { channel in - let handshakeHandler = SOCKSServerHandshakeHandler() - return channel.pipeline.addHandlers([ - handshakeHandler, - SOCKSTestHandler(handshakeHandler: handshakeHandler), - TestHTTPServer( - expectedURL: expectedURL, - expectedResponse: expectedResponse, - file: file, - line: line - ), - ]) + channel.eventLoop.makeCompletedFuture { + let handshakeHandler = SOCKSServerHandshakeHandler() + try channel.pipeline.syncOperations.addHandlers([ + handshakeHandler, + SOCKSTestHandler(handshakeHandler: handshakeHandler), + TestHTTPServer( + expectedURL: expectedURL, + expectedResponse: expectedResponse, + file: file, + line: line + ), + ]) + } } } self.channel = try bootstrap.bind(host: "localhost", port: 0).wait() @@ -112,15 +116,19 @@ class SOCKSTestHandler: ChannelInboundHandler, RemovableChannelHandler { ), promise: nil ) - context.channel.pipeline.addHandlers( - [ - ByteToMessageHandler(HTTPRequestDecoder()), - HTTPResponseEncoder(), - ], - position: .after(self) - ).whenSuccess { - context.channel.pipeline.removeHandler(self, promise: nil) - context.channel.pipeline.removeHandler(self.handshakeHandler, promise: nil) + + do { + try context.channel.pipeline.syncOperations.addHandlers( + [ + ByteToMessageHandler(HTTPRequestDecoder()), + HTTPResponseEncoder(), + ], + position: .after(self) + ) + context.channel.pipeline.syncOperations.removeHandler(self, promise: nil) + context.channel.pipeline.syncOperations.removeHandler(self.handshakeHandler, promise: nil) + } catch { + context.fireErrorCaught(error) } } } diff --git a/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests.swift b/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests.swift index 96cdf68f6..988ba6e3f 100644 --- a/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests.swift @@ -39,7 +39,7 @@ class TLSEventsHandlerTests: XCTestCase { let embedded = EmbeddedChannel(handlers: [tlsEventsHandler]) XCTAssertNotNil(tlsEventsHandler.tlsEstablishedFuture) - XCTAssertNoThrow(try embedded.pipeline.removeHandler(tlsEventsHandler).wait()) + XCTAssertNoThrow(try embedded.pipeline.syncOperations.removeHandler(tlsEventsHandler).wait()) XCTAssertThrowsError(try XCTUnwrap(tlsEventsHandler.tlsEstablishedFuture).wait()) } diff --git a/Tests/AsyncHTTPClientTests/TransactionTests.swift b/Tests/AsyncHTTPClientTests/TransactionTests.swift index 34349496d..3316de370 100644 --- a/Tests/AsyncHTTPClientTests/TransactionTests.swift +++ b/Tests/AsyncHTTPClientTests/TransactionTests.swift @@ -29,12 +29,10 @@ typealias PreparedRequest = HTTPClientRequest.Prepared @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) final class TransactionTests: XCTestCase { func testCancelAsyncRequest() { - // creating the `XCTestExpectation` off the main thread crashes on Linux with Swift 5.6 - // therefore we create it here as a workaround which works fine - let scheduledRequestCanceled = self.expectation(description: "scheduled request canceled") XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + let scheduledRequestCanceled = loop.makePromise(of: Void.self) + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .GET @@ -45,11 +43,11 @@ final class TransactionTests: XCTestCase { } let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) let queuer = MockTaskQueuer { _ in - scheduledRequestCanceled.fulfill() + scheduledRequestCanceled.succeed() } transaction.requestWasQueued(queuer) @@ -64,16 +62,14 @@ final class TransactionTests: XCTestCase { } // self.fulfillment(of:) is not available on Linux - _ = { - self.wait(for: [scheduledRequestCanceled], timeout: 1) - }() + try await scheduledRequestCanceled.futureResult.timeout(after: .seconds(1)).get() } } func testDeadlineExceededWhileQueuedAndExecutorImmediatelyCancelsTask() { XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .GET @@ -84,7 +80,7 @@ final class TransactionTests: XCTestCase { } let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) let queuer = MockTaskQueuer() @@ -127,8 +123,8 @@ final class TransactionTests: XCTestCase { func testResponseStreamingWorks() { XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .GET @@ -140,12 +136,12 @@ final class TransactionTests: XCTestCase { } let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) let executor = MockRequestExecutor( pauseRequestBodyPartStreamAfterASingleWrite: true, - eventLoop: embeddedEventLoop + eventLoop: loop ) transaction.willExecuteRequest(executor) @@ -186,8 +182,8 @@ final class TransactionTests: XCTestCase { func testIgnoringResponseBodyWorks() { XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .GET @@ -199,7 +195,7 @@ final class TransactionTests: XCTestCase { } var tuple: (Transaction, Task)! = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) let transaction = tuple.0 @@ -208,9 +204,10 @@ final class TransactionTests: XCTestCase { let executor = MockRequestExecutor( pauseRequestBodyPartStreamAfterASingleWrite: true, - eventLoop: embeddedEventLoop + eventLoop: loop ) executor.runRequest(transaction) + await loop.run() let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["foo": "bar"]) XCTAssertFalse(executor.signalledDemandForResponseBody) @@ -234,8 +231,8 @@ final class TransactionTests: XCTestCase { func testWriteBackpressureWorks() { XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } let streamWriter = AsyncSequenceWriter() XCTAssertFalse(streamWriter.hasDemand, "Did not expect to have a demand at this point") @@ -251,12 +248,13 @@ final class TransactionTests: XCTestCase { } let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) - let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + let executor = MockRequestExecutor(eventLoop: loop) executor.runRequest(transaction) + await loop.run() for i in 0..<100 { XCTAssertFalse(streamWriter.hasDemand, "Did not expect to have demand yet") @@ -316,7 +314,7 @@ final class TransactionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() - var maybeHTTP2Connection: HTTP2Connection? + var maybeHTTP2Connection: HTTP2Connection.SendableView? XCTAssertNoThrow( maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( to: httpBin.port, @@ -364,8 +362,8 @@ final class TransactionTests: XCTestCase { func testSimplePostRequest() { XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .POST @@ -377,11 +375,12 @@ final class TransactionTests: XCTestCase { } let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) - let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + let executor = MockRequestExecutor(eventLoop: loop) executor.runRequest(transaction) + await loop.run() executor.resumeRequestBodyStream() XCTAssertNoThrow( try executor.receiveRequestBody { @@ -403,8 +402,8 @@ final class TransactionTests: XCTestCase { func testPostStreamFails() { XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } let writer = AsyncSequenceWriter() @@ -418,11 +417,12 @@ final class TransactionTests: XCTestCase { } let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) - let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + let executor = MockRequestExecutor(eventLoop: loop) executor.runRequest(transaction) + await loop.run() executor.resumeRequestBodyStream() await writer.demand() @@ -447,8 +447,8 @@ final class TransactionTests: XCTestCase { func testResponseStreamFails() { XCTAsyncTest(timeout: 30) { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .GET @@ -460,12 +460,12 @@ final class TransactionTests: XCTestCase { } let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) let executor = MockRequestExecutor( pauseRequestBodyPartStreamAfterASingleWrite: true, - eventLoop: embeddedEventLoop + eventLoop: loop ) transaction.willExecuteRequest(executor) @@ -518,7 +518,7 @@ final class TransactionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() - var maybeHTTP2Connection: HTTP2Connection? + var maybeHTTP2Connection: HTTP2Connection.SendableView? XCTAssertNoThrow( maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( to: httpBin.port, @@ -586,22 +586,31 @@ final class TransactionTests: XCTestCase { // tasks. Since we want to wait for things to happen in tests, we need to `async let`, which creates // implicit tasks. Therefore we need to wrap our iterator struct. @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -actor SharedIterator where Wrapped.Element: Sendable { - private var wrappedIterator: Wrapped.AsyncIterator - private var nextCallInProgress: Bool = false +final class SharedIterator: Sendable where Wrapped.Element: Sendable { + private struct State: @unchecked Sendable { + var wrappedIterator: Wrapped.AsyncIterator + var nextCallInProgress: Bool = false + } + + private let state: NIOLockedValueBox init(_ sequence: Wrapped) { - self.wrappedIterator = sequence.makeAsyncIterator() + self.state = NIOLockedValueBox(State(wrappedIterator: sequence.makeAsyncIterator())) } func next() async throws -> Wrapped.Element? { - precondition(self.nextCallInProgress == false) - self.nextCallInProgress = true - var iter = self.wrappedIterator + var iter = self.state.withLockedValue { + precondition($0.nextCallInProgress == false) + $0.nextCallInProgress = true + return $0.wrappedIterator + } + defer { - precondition(self.nextCallInProgress == true) - self.nextCallInProgress = false - self.wrappedIterator = iter + self.state.withLockedValue { + precondition($0.nextCallInProgress == true) + $0.nextCallInProgress = false + $0.wrappedIterator = iter + } } return try await iter.next() } @@ -609,7 +618,7 @@ actor SharedIterator where Wrapped.Element: Sendable { /// non fail-able promise that only supports one observer @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -private actor Promise { +private actor Promise { private enum State { case initialised case fulfilled(Value) @@ -648,6 +657,35 @@ private actor Promise { @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension Transaction { + #if compiler(>=6.0) + fileprivate static func makeWithResultTask( + request: sending PreparedRequest, + requestOptions: RequestOptions = .forTests(), + logger: Logger = Logger(label: "test"), + connectionDeadline: NIODeadline = .distantFuture, + preferredEventLoop: EventLoop + ) async -> (Transaction, _Concurrency.Task) { + let transactionPromise = Promise() + let task = Task { + try await withCheckedThrowingContinuation { + (continuation: CheckedContinuation) in + let transaction = Transaction( + request: request, + requestOptions: requestOptions, + logger: logger, + connectionDeadline: connectionDeadline, + preferredEventLoop: preferredEventLoop, + responseContinuation: continuation + ) + Task { + await transactionPromise.fulfil(transaction) + } + } + } + + return (await transactionPromise.value, task) + } + #else fileprivate static func makeWithResultTask( request: PreparedRequest, requestOptions: RequestOptions = .forTests(), @@ -655,10 +693,17 @@ extension Transaction { connectionDeadline: NIODeadline = .distantFuture, preferredEventLoop: EventLoop ) async -> (Transaction, _Concurrency.Task) { + // It isn't sendable ... but on 6.0 and later we use 'sending'. + struct UnsafePrepareRequest: @unchecked Sendable { + var value: PreparedRequest + } + let transactionPromise = Promise() + let unsafe = UnsafePrepareRequest(value: request) let task = Task { try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let request = unsafe.value let transaction = Transaction( request: request, requestOptions: requestOptions, @@ -675,4 +720,5 @@ extension Transaction { return (await transactionPromise.value, task) } + #endif }