diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1761880d..926f2fbe 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,11 +18,10 @@ jobs: fail-fast: false matrix: swift-image: - - swift:5.8-jammy - - swift:5.9-jammy - - swift:5.10-noble - - swiftlang/swift:nightly-6.0-jammy - - swiftlang/swift:nightly-main-jammy + - swift:6.0-jammy + - swift:6.1-noble + - swift:6.2-noble + - swiftlang/swift:nightly-main-noble container: ${{ matrix.swift-image }} runs-on: ubuntu-latest steps: @@ -33,11 +32,16 @@ jobs: [[ -z "${SWIFT_VERSION}" ]] && SWIFT_VERSION="$(cat /.swift_tag 2>/dev/null || true)" printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" swift --version + - name: Install curl for Codecov + run: apt-get update -y -q && apt-get install -y curl - name: Check out package - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Run unit tests with Thread Sanitizer + shell: bash run: | - swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' --sanitize=thread --enable-code-coverage + # https://github.com/swiftlang/swift/issues/74042 was never fixed in 5.10 and swift-crypto hits it in 6.0 as well + SANITIZE="$([[ "${SWIFT_VERSION}" =~ ^swift-(5|6\.0) ]] || echo '--sanitize=thread')" + swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' ${SANITIZE} --enable-code-coverage - name: Submit code coverage uses: vapor/swift-codecov-action@v0.3 with: @@ -48,18 +52,18 @@ jobs: fail-fast: false matrix: postgres-image: - - postgres:16 - - postgres:14 - - postgres:12 + - postgres:17 + - postgres:15 + - postgres:13 include: - - postgres-image: postgres:16 + - postgres-image: postgres:17 postgres-auth: scram-sha-256 - - postgres-image: postgres:14 + - postgres-image: postgres:15 postgres-auth: md5 - - postgres-image: postgres:12 + - postgres-image: postgres:13 postgres-auth: trust container: - image: swift:5.10-noble + image: swift:6.2-noble volumes: [ 'pgrunshare:/var/run/postgresql' ] runs-on: ubuntu-latest env: @@ -104,15 +108,15 @@ jobs: [[ -z "${SWIFT_VERSION}" ]] && SWIFT_VERSION="$(cat /.swift_tag 2>/dev/null || true)" printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" && swift --version - name: Check out package - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: { path: 'postgres-nio' } - name: Run integration tests run: swift test --package-path postgres-nio --filter=^IntegrationTests - name: Check out postgres-kit dependent - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: { repository: 'vapor/postgres-kit', path: 'postgres-kit' } - name: Check out fluent-postgres-driver dependent - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: { repository: 'vapor/fluent-postgres-driver', path: 'fluent-postgres-driver' } - name: Use local package in dependents run: | @@ -133,14 +137,14 @@ jobs: postgres-auth: # Only test one auth method on macOS, Linux tests will cover the others - scram-sha-256 - xcode-version: - - '~14.3' - - '~15' + macos-version: + - 'macos-14' + - 'macos-15' include: - - xcode-version: '~14.3' - macos-version: 'macos-13' - - xcode-version: '~15' - macos-version: 'macos-14' + - macos-version: 'macos-14' + xcode-version: 'latest-stable' + - macos-version: 'macos-15' + xcode-version: 'latest-stable' runs-on: ${{ matrix.macos-version }} env: POSTGRES_HOSTNAME: 127.0.0.1 @@ -158,28 +162,23 @@ jobs: - name: Install Postgres, setup DB and auth, and wait for server start run: | export PATH="$(brew --prefix)/opt/${POSTGRES_FORMULA}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test - # ** BEGIN ** Work around bug in both Homebrew and GHA - (brew upgrade python@3.11 || true) && (brew link --force --overwrite python@3.11 || true) - (brew upgrade python@3.12 || true) && (brew link --force --overwrite python@3.12 || true) - (brew upgrade || true) - # ** END ** Work around bug in both Homebrew and GHA brew install --overwrite "${POSTGRES_FORMULA}" brew link --overwrite --force "${POSTGRES_FORMULA}" initdb --locale=C --auth-host "${POSTGRES_AUTH_METHOD}" -U "${POSTGRES_USER}" --pwfile=<(echo "${POSTGRES_PASSWORD}") pg_ctl start --wait timeout-minutes: 15 - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Run all tests run: swift test - + api-breakage: if: github.event_name == 'pull_request' runs-on: ubuntu-latest container: swift:noble steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: fetch-depth: 0 # https://github.com/actions/checkout/issues/766 @@ -187,22 +186,3 @@ jobs: run: | git config --global --add safe.directory "${GITHUB_WORKSPACE}" swift package diagnose-api-breaking-changes origin/main - -# gh-codeql: -# if: ${{ false }} -# runs-on: ubuntu-latest -# container: swift:noble -# permissions: { actions: write, contents: read, security-events: write } -# steps: -# - name: Check out code -# uses: actions/checkout@v4 -# - name: Mark repo safe in non-fake global config -# run: git config --global --add safe.directory "${GITHUB_WORKSPACE}" -# - name: Initialize CodeQL -# uses: github/codeql-action/init@v3 -# with: -# languages: swift -# - name: Perform build -# run: swift build -# - name: Run CodeQL analyze -# uses: github/codeql-action/analyze@v3 diff --git a/Benchmarks/.gitignore b/Benchmarks/.gitignore new file mode 100644 index 00000000..24e5b0a1 --- /dev/null +++ b/Benchmarks/.gitignore @@ -0,0 +1 @@ +.build diff --git a/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift b/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift new file mode 100644 index 00000000..9cc535d4 --- /dev/null +++ b/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift @@ -0,0 +1,99 @@ +import _ConnectionPoolModule +import _ConnectionPoolTestUtils +import Benchmark + +let benchmarks: @Sendable () -> Void = { + Benchmark("Lease/Release 1k requests: 50 parallel", configuration: .init(scalingFactor: .kilo)) { benchmark in + let clock = MockClock() + let factory = MockConnectionFactory(autoMaxStreams: 1) + var configuration = ConnectionPoolConfiguration() + configuration.maximumConnectionSoftLimit = 50 + configuration.maximumConnectionHardLimit = 50 + + let pool = ConnectionPool( + configuration: configuration, + idGenerator: ConnectionIDGenerator(), + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + await withTaskGroup { taskGroup in + + taskGroup.addTask { + await pool.run() + } + + let sequential = benchmark.scaledIterations.upperBound / configuration.maximumConnectionSoftLimit + + benchmark.startMeasurement() + + for parallel in 0..(autoMaxStreams: 1) + var configuration = ConnectionPoolConfiguration() + configuration.maximumConnectionSoftLimit = 50 + configuration.maximumConnectionHardLimit = 50 + + let pool = ConnectionPool( + configuration: configuration, + idGenerator: ConnectionIDGenerator(), + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + await withTaskGroup { taskGroup in + + taskGroup.addTask { + await pool.run() + } + + let sequential = benchmark.scaledIterations.upperBound / configuration.maximumConnectionSoftLimit + + benchmark.startMeasurement() + + for _ in benchmark.scaledIterations { + do { + try await pool.withConnection { connection in + blackHole(connection) + } + } catch { + fatalError() + } + } + + benchmark.stopMeasurement() + + taskGroup.cancelAll() + } + } +} diff --git a/Benchmarks/Package.swift b/Benchmarks/Package.swift new file mode 100644 index 00000000..11407176 --- /dev/null +++ b/Benchmarks/Package.swift @@ -0,0 +1,28 @@ +// swift-tools-version: 6.0 + +import PackageDescription + +let package = Package( + name: "benchmarks", + platforms: [ + .macOS("14") + ], + dependencies: [ + .package(path: "../"), + .package(url: "https://github.com/ordo-one/package-benchmark.git", from: "1.29.0"), + ], + targets: [ + .executableTarget( + name: "ConnectionPoolBenchmarks", + dependencies: [ + .product(name: "_ConnectionPoolModule", package: "postgres-nio"), + .product(name: "_ConnectionPoolTestUtils", package: "postgres-nio"), + .product(name: "Benchmark", package: "package-benchmark"), + ], + path: "Benchmarks/ConnectionPoolBenchmarks", + plugins: [ + .plugin(name: "BenchmarkPlugin", package: "package-benchmark") + ] + ), + ] +) diff --git a/Package.swift b/Package.swift index 5c83eded..ae0e8a5d 100644 --- a/Package.swift +++ b/Package.swift @@ -1,9 +1,15 @@ -// swift-tools-version:5.8 +// swift-tools-version:6.0 import PackageDescription +#if compiler(>=6.1) +let swiftSettings: [SwiftSetting] = [] +#else let swiftSettings: [SwiftSetting] = [ - .enableUpcomingFeature("StrictConcurrency"), + // Sadly the 6.0 compiler concurrency checker finds false positives. + // To be able to compile, lets reduce the language version down to 5 for 6.0 only. + .swiftLanguageMode(.v5) ] +#endif let package = Package( name: "postgres-nio", @@ -16,14 +22,15 @@ let package = Package( products: [ .library(name: "PostgresNIO", targets: ["PostgresNIO"]), .library(name: "_ConnectionPoolModule", targets: ["_ConnectionPoolModule"]), + .library(name: "_ConnectionPoolTestUtils", targets: ["_ConnectionPoolTestUtils"]), ], dependencies: [ .package(url: "https://github.com/apple/swift-atomics.git", from: "1.2.0"), .package(url: "https://github.com/apple/swift-collections.git", from: "1.0.4"), - .package(url: "https://github.com/apple/swift-nio.git", from: "2.59.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.81.0"), .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.19.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.25.0"), - .package(url: "https://github.com/apple/swift-crypto.git", "2.0.0" ..< "4.0.0"), + .package(url: "https://github.com/apple/swift-crypto.git", "3.9.0" ..< "5.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.4.1"), .package(url: "https://github.com/apple/swift-log.git", from: "1.5.3"), .package(url: "https://github.com/swift-server/swift-service-lifecycle.git", from: "2.5.0"), @@ -35,6 +42,7 @@ let package = Package( .target(name: "_ConnectionPoolModule"), .product(name: "Atomics", package: "swift-atomics"), .product(name: "Crypto", package: "swift-crypto"), + .product(name: "_CryptoExtras", package: "swift-crypto"), .product(name: "Logging", package: "swift-log"), .product(name: "Metrics", package: "swift-metrics"), .product(name: "NIO", package: "swift-nio"), @@ -57,6 +65,15 @@ let package = Package( path: "Sources/ConnectionPoolModule", swiftSettings: swiftSettings ), + .target( + name: "_ConnectionPoolTestUtils", + dependencies: [ + "_ConnectionPoolModule", + .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), + ], + path: "Sources/ConnectionPoolTestUtils", + swiftSettings: swiftSettings + ), .testTarget( name: "PostgresNIOTests", dependencies: [ @@ -70,6 +87,7 @@ let package = Package( name: "ConnectionPoolModuleTests", dependencies: [ .target(name: "_ConnectionPoolModule"), + .target(name: "_ConnectionPoolTestUtils"), .product(name: "DequeModule", package: "swift-collections"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), diff --git a/README.md b/README.md index bc56953b..fa4495e2 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,5 @@

- - - - PostgresNIO - +PostgresNIO

@@ -16,7 +12,7 @@ Continuous Integration - Swift 5.8+ + Swift 6.0+ SSWG Incubation Level: Graduated @@ -167,7 +163,7 @@ Please see [SECURITY.md] for details on the security process. [Team Chat]: https://discord.gg/vapor [MIT License]: LICENSE [Continuous Integration]: https://github.com/vapor/postgres-nio/actions -[Swift 5.8]: https://swift.org +[Swift 6.0]: https://swift.org [Security.md]: https://github.com/vapor/.github/blob/main/SECURITY.md [`PostgresConnection`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresconnection diff --git a/Sources/ConnectionPoolModule/ConnectionLease.swift b/Sources/ConnectionPoolModule/ConnectionLease.swift new file mode 100644 index 00000000..77591a58 --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionLease.swift @@ -0,0 +1,17 @@ +public struct ConnectionLease: Sendable { + public var connection: Connection + + @usableFromInline + let _release: @Sendable (Connection) -> () + + @inlinable + public init(connection: Connection, release: @escaping @Sendable (Connection) -> Void) { + self.connection = connection + self._release = release + } + + @inlinable + public func release() { + self._release(self.connection) + } +} diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index 03c269ee..40d52a5a 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -88,7 +88,7 @@ public protocol ConnectionRequestProtocol: Sendable { /// A function that is called with a connection or a /// `PoolError`. - func complete(with: Result) + func complete(with: Result, ConnectionPoolError>) } @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) @@ -271,25 +271,27 @@ public final class ConnectionPool< } } + @inlinable public func run() async { await withTaskCancellationHandler { - #if os(Linux) || compiler(>=5.9) if #available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) { return await withDiscardingTaskGroup() { taskGroup in await self.run(in: &taskGroup) } } - #endif return await withTaskGroup(of: Void.self) { taskGroup in await self.run(in: &taskGroup) } } onCancel: { - let actions = self.stateBox.withLockedValue { state in - state.stateMachine.triggerForceShutdown() - } + self.triggerForceShutdown() + } + } - self.runStateMachineActions(actions) + public func triggerForceShutdown() { + let actions = self.stateBox.withLockedValue { state in + state.stateMachine.triggerForceShutdown() } + self.runStateMachineActions(actions) } // MARK: - Private Methods - @@ -313,16 +315,16 @@ public final class ConnectionPool< case scheduleTimer(StateMachine.Timer) } - #if os(Linux) || compiler(>=5.9) @available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) - private func run(in taskGroup: inout DiscardingTaskGroup) async { + @inlinable + /* private */ func run(in taskGroup: inout DiscardingTaskGroup) async { for await event in self.eventStream { self.runEvent(event, in: &taskGroup) } } - #endif - private func run(in taskGroup: inout TaskGroup) async { + @inlinable + /* private */ func run(in taskGroup: inout TaskGroup) async { var running = 0 for await event in self.eventStream { running += 1 @@ -335,7 +337,8 @@ public final class ConnectionPool< } } - private func runEvent(_ event: NewPoolActions, in taskGroup: inout some TaskGroupProtocol) { + @inlinable + /* private */ func runEvent(_ event: NewPoolActions, in taskGroup: inout some TaskGroupProtocol) { switch event { case .makeConnection(let request): self.makeConnection(for: request, in: &taskGroup) @@ -402,8 +405,11 @@ public final class ConnectionPool< /*private*/ func runRequestAction(_ action: StateMachine.RequestAction) { switch action { case .leaseConnection(let requests, let connection): + let lease = ConnectionLease(connection: connection) { connection in + self.releaseConnection(connection) + } for request in requests { - request.complete(with: .success(connection)) + request.complete(with: .success(lease)) } case .failRequest(let request, let error): @@ -507,11 +513,7 @@ public final class ConnectionPool< await withTaskGroup(of: TimerRunResult.self, returning: Void.self) { taskGroup in taskGroup.addTask { do { - #if os(Linux) || compiler(>=5.9) try await self.clock.sleep(for: timer.duration) - #else - try await self.clock.sleep(until: self.clock.now.advanced(by: timer.duration), tolerance: nil) - #endif return .timerTriggered } catch { return .timerCancelled @@ -571,20 +573,6 @@ extension PoolConfiguration { } } -#if swift(<5.9) -// This should be removed once we support Swift 5.9+ only -extension AsyncStream { - static func makeStream( - of elementType: Element.Type = Element.self, - bufferingPolicy limit: Continuation.BufferingPolicy = .unbounded - ) -> (stream: AsyncStream, continuation: AsyncStream.Continuation) { - var continuation: AsyncStream.Continuation! - let stream = AsyncStream(bufferingPolicy: limit) { continuation = $0 } - return (stream: stream, continuation: continuation!) - } -} -#endif - @usableFromInline protocol TaskGroupProtocol { // We need to call this `addTask_` because some Swift versions define this @@ -593,7 +581,6 @@ protocol TaskGroupProtocol { mutating func addTask_(operation: @escaping @Sendable () async -> Void) } -#if os(Linux) || swift(>=5.9) @available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) extension DiscardingTaskGroup: TaskGroupProtocol { @inlinable @@ -601,7 +588,6 @@ extension DiscardingTaskGroup: TaskGroupProtocol { self.addTask(priority: nil, operation: operation) } } -#endif extension TaskGroup: TaskGroupProtocol { @inlinable diff --git a/Sources/ConnectionPoolModule/ConnectionPoolError.swift b/Sources/ConnectionPoolModule/ConnectionPoolError.swift index 1f1e1d2c..3abfe778 100644 --- a/Sources/ConnectionPoolModule/ConnectionPoolError.swift +++ b/Sources/ConnectionPoolModule/ConnectionPoolError.swift @@ -1,16 +1,25 @@ public struct ConnectionPoolError: Error, Hashable { - enum Base: Error, Hashable { + @usableFromInline + enum Base: Error, Hashable, Sendable { case requestCancelled case poolShutdown } - private let base: Base + @usableFromInline + let base: Base + @inlinable init(_ base: Base) { self.base = base } /// The connection requests got cancelled - public static let requestCancelled = ConnectionPoolError(.requestCancelled) + @inlinable + public static var requestCancelled: Self { + ConnectionPoolError(.requestCancelled) + } /// The connection requests can't be fulfilled as the pool has already been shutdown - public static let poolShutdown = ConnectionPoolError(.poolShutdown) + @inlinable + public static var poolShutdown: Self { + ConnectionPoolError(.poolShutdown) + } } diff --git a/Sources/ConnectionPoolModule/ConnectionRequest.swift b/Sources/ConnectionPoolModule/ConnectionRequest.swift index 19ed9bd2..d6654a27 100644 --- a/Sources/ConnectionPoolModule/ConnectionRequest.swift +++ b/Sources/ConnectionPoolModule/ConnectionRequest.swift @@ -5,23 +5,24 @@ public struct ConnectionRequest: ConnectionRequest public var id: ID @usableFromInline - private(set) var continuation: CheckedContinuation + private(set) var continuation: CheckedContinuation, any Error> @inlinable init( id: Int, - continuation: CheckedContinuation + continuation: CheckedContinuation, any Error> ) { self.id = id self.continuation = continuation } - public func complete(with result: Result) { + public func complete(with result: Result, ConnectionPoolError>) { self.continuation.resume(with: result) } } -fileprivate let requestIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator() +@usableFromInline +let requestIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator() @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) extension ConnectionPool where Request == ConnectionRequest { @@ -44,7 +45,8 @@ extension ConnectionPool where Request == ConnectionRequest { ) } - public func leaseConnection() async throws -> Connection { + @inlinable + public func leaseConnection() async throws -> ConnectionLease { let requestID = requestIDGenerator.next() let connection = try await withTaskCancellationHandler { @@ -52,7 +54,7 @@ extension ConnectionPool where Request == ConnectionRequest { throw CancellationError() } - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation, Error>) in let request = Request( id: requestID, continuation: continuation @@ -67,9 +69,10 @@ extension ConnectionPool where Request == ConnectionRequest { return connection } + @inlinable public func withConnection(_ closure: (Connection) async throws -> Result) async throws -> Result { - let connection = try await self.leaseConnection() - defer { self.releaseConnection(connection) } - return try await closure(connection) + let lease = try await self.leaseConnection() + defer { lease.release() } + return try await closure(lease.connection) } } diff --git a/Sources/ConnectionPoolModule/NIOLock.swift b/Sources/ConnectionPoolModule/NIOLock.swift index dbc7dbe9..b6cd7164 100644 --- a/Sources/ConnectionPoolModule/NIOLock.swift +++ b/Sources/ConnectionPoolModule/NIOLock.swift @@ -24,6 +24,13 @@ import WinSDK import Glibc #elseif canImport(Musl) import Musl +#elseif canImport(Bionic) +import Bionic +#elseif canImport(WASILibc) +import WASILibc +#if canImport(wasi_pthread) +import wasi_pthread +#endif #else #error("The concurrency NIOLock module was unable to identify your C library.") #endif @@ -37,61 +44,61 @@ typealias LockPrimitive = pthread_mutex_t #endif @usableFromInline -enum LockOperations { } +enum LockOperations {} extension LockOperations { @inlinable static func create(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) InitializeSRWLock(mutex) -#else + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) var attr = pthread_mutexattr_t() pthread_mutexattr_init(&attr) debugOnly { pthread_mutexattr_settype(&attr, .init(PTHREAD_MUTEX_ERRORCHECK)) } - + let err = pthread_mutex_init(mutex, &attr) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } - + @inlinable static func destroy(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) // SRWLOCK does not need to be free'd -#else + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) let err = pthread_mutex_destroy(mutex) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } - + @inlinable static func lock(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) AcquireSRWLockExclusive(mutex) -#else + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) let err = pthread_mutex_lock(mutex) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } - + @inlinable static func unlock(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) ReleaseSRWLockExclusive(mutex) -#else + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) let err = pthread_mutex_unlock(mutex) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } } @@ -125,49 +132,52 @@ extension LockOperations { // See also: https://github.com/apple/swift/pull/40000 @usableFromInline final class LockStorage: ManagedBuffer { - + @inlinable static func create(value: Value) -> Self { let buffer = Self.create(minimumCapacity: 1) { _ in - return value + value } - let storage = unsafeDowncast(buffer, to: Self.self) - + // Intentionally using a force cast here to avoid a miss compiliation in 5.10. + // This is as fast as an unsafeDownCast since ManagedBuffer is inlined and the optimizer + // can eliminate the upcast/downcast pair + let storage = buffer as! Self + storage.withUnsafeMutablePointers { _, lockPtr in LockOperations.create(lockPtr) } - + return storage } - + @inlinable func lock() { self.withUnsafeMutablePointerToElements { lockPtr in LockOperations.lock(lockPtr) } } - + @inlinable func unlock() { self.withUnsafeMutablePointerToElements { lockPtr in LockOperations.unlock(lockPtr) } } - + @inlinable deinit { self.withUnsafeMutablePointerToElements { lockPtr in LockOperations.destroy(lockPtr) } } - + @inlinable func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { try self.withUnsafeMutablePointerToElements { lockPtr in - return try body(lockPtr) + try body(lockPtr) } } - + @inlinable func withLockedValue(_ mutate: (inout Value) throws -> T) rethrows -> T { try self.withUnsafeMutablePointers { valuePtr, lockPtr in @@ -178,21 +188,18 @@ final class LockStorage: ManagedBuffer { } } -extension LockStorage: @unchecked Sendable { } - /// A threading lock based on `libpthread` instead of `libdispatch`. /// -/// - note: ``NIOLock`` has reference semantics. +/// - Note: ``NIOLock`` has reference semantics. /// /// This object provides a lock on top of a single `pthread_mutex_t`. This kind /// of lock is safe to use with `libpthread`-based threading models, such as the /// one used by NIO. On Windows, the lock is based on the substantially similar /// `SRWLOCK` type. -@usableFromInline struct NIOLock { @usableFromInline internal let _storage: LockStorage - + /// Create a new lock. @inlinable init() { @@ -219,7 +226,7 @@ struct NIOLock { @inlinable internal func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { - return try self._storage.withLockPrimitive(body) + try self._storage.withLockPrimitive(body) } } @@ -242,12 +249,12 @@ extension NIOLock { } @inlinable - func withLockVoid(_ body: () throws -> Void) rethrows -> Void { + func withLockVoid(_ body: () throws -> Void) rethrows { try self.withLock(body) } } -extension NIOLock: Sendable {} +extension NIOLock: @unchecked Sendable {} extension UnsafeMutablePointer { @inlinable @@ -263,6 +270,10 @@ extension UnsafeMutablePointer { /// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion. @inlinable internal func debugOnly(_ body: () -> Void) { - // FIXME: duplicated with NIO. - assert({ body(); return true }()) + assert( + { + body() + return true + }() + ) } diff --git a/Sources/ConnectionPoolModule/NIOLockedValueBox.swift b/Sources/ConnectionPoolModule/NIOLockedValueBox.swift index e5a3e6a2..c9cd89e0 100644 --- a/Sources/ConnectionPoolModule/NIOLockedValueBox.swift +++ b/Sources/ConnectionPoolModule/NIOLockedValueBox.swift @@ -17,7 +17,7 @@ /// Provides locked access to `Value`. /// -/// - note: ``NIOLockedValueBox`` has reference semantics and holds the `Value` +/// - Note: ``NIOLockedValueBox`` has reference semantics and holds the `Value` /// alongside a lock behind a reference. /// /// This is no different than creating a ``Lock`` and protecting all @@ -39,8 +39,48 @@ struct NIOLockedValueBox { /// Access the `Value`, allowing mutation of it. @inlinable func withLockedValue(_ mutate: (inout Value) throws -> T) rethrows -> T { - return try self._storage.withLockedValue(mutate) + try self._storage.withLockedValue(mutate) + } + + /// Provides an unsafe view over the lock and its value. + /// + /// This can be beneficial when you require fine grained control over the lock in some + /// situations but don't want lose the benefits of ``withLockedValue(_:)`` in others by + /// switching to ``NIOLock``. + var unsafe: Unsafe { + Unsafe(_storage: self._storage) + } + + /// Provides an unsafe view over the lock and its value. + struct Unsafe { + @usableFromInline + let _storage: LockStorage + + /// Manually acquire the lock. + @inlinable + func lock() { + self._storage.lock() + } + + /// Manually release the lock. + @inlinable + func unlock() { + self._storage.unlock() + } + + /// Mutate the value, assuming the lock has been acquired manually. + /// + /// - Parameter mutate: A closure with scoped access to the value. + /// - Returns: The result of the `mutate` closure. + @inlinable + func withValueAssumingLockIsAcquired( + _ mutate: (_ value: inout Value) throws -> Result + ) rethrows -> Result { + try self._storage.withUnsafeMutablePointerToHeader { value in + try mutate(&value.pointee) + } + } } } -extension NIOLockedValueBox: Sendable where Value: Sendable {} +extension NIOLockedValueBox: @unchecked Sendable where Value: Sendable {} diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift index f26f244d..a8e97ffd 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -132,6 +132,12 @@ extension PoolStateMachine { @usableFromInline var info: ConnectionAvailableInfo + + @inlinable + init(use: ConnectionUse, info: ConnectionAvailableInfo) { + self.use = use + self.info = info + } } mutating func refillConnections() -> [ConnectionRequest] { @@ -623,7 +629,7 @@ extension PoolStateMachine { // MARK: - Private functions - - @usableFromInline + @inlinable /*private*/ func getConnectionUse(index: Int) -> ConnectionUse { switch index { case 0.. AvailableConnectionContext { precondition(self.connections[index].isAvailable) let use = self.getConnectionUse(index: index) diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift index 2fb68a2d..9912f13a 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift @@ -164,7 +164,7 @@ extension PoolStateMachine { } } - @usableFromInline + @inlinable var isLeased: Bool { switch self.state { case .leased: @@ -174,7 +174,7 @@ extension PoolStateMachine { } } - @usableFromInline + @inlinable var isConnected: Bool { switch self.state { case .idle, .leased: diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index 3b996033..8d995fa2 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -1,7 +1,9 @@ #if canImport(Darwin) import Darwin -#else +#elseif canImport(Glibc) import Glibc +#elseif canImport(Musl) +import Musl #endif @usableFromInline @@ -432,6 +434,7 @@ struct PoolStateMachine< fatalError("Unimplemented") } + @usableFromInline mutating func triggerForceShutdown() -> Action { switch self.poolState { case .running: diff --git a/Sources/ConnectionPoolModule/TinyFastSequence.swift b/Sources/ConnectionPoolModule/TinyFastSequence.swift index dff8a30b..df140c98 100644 --- a/Sources/ConnectionPoolModule/TinyFastSequence.swift +++ b/Sources/ConnectionPoolModule/TinyFastSequence.swift @@ -29,6 +29,12 @@ struct TinyFastSequence: Sequence { self.base = .none(reserveCapacity: 0) case 1: self.base = .one(collection.first!, reserveCapacity: 0) + case 2: + self.base = .two( + collection.first!, + collection[collection.index(after: collection.startIndex)], + reserveCapacity: 0 + ) default: if let collection = collection as? Array { self.base = .n(collection) @@ -46,7 +52,7 @@ struct TinyFastSequence: Sequence { case 1: self.base = .one(max2Sequence.first!, reserveCapacity: 0) case 2: - self.base = .n(Array(max2Sequence)) + self.base = .two(max2Sequence.first!, max2Sequence.second!, reserveCapacity: 0) default: fatalError() } @@ -169,7 +175,7 @@ struct TinyFastSequence: Sequence { case .n(let array): if self.index < array.endIndex { - defer { self.index += 1} + defer { self.index += 1 } return array[self.index] } return nil diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift b/Sources/ConnectionPoolTestUtils/MockClock.swift similarity index 84% rename from Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift rename to Sources/ConnectionPoolTestUtils/MockClock.swift index cd08d54e..34bf17e3 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift +++ b/Sources/ConnectionPoolTestUtils/MockClock.swift @@ -1,31 +1,32 @@ -@testable import _ConnectionPoolModule +import _ConnectionPoolModule import Atomics import DequeModule +import NIOConcurrencyHelpers @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class MockClock: Clock { - struct Instant: InstantProtocol, Comparable { - typealias Duration = Swift.Duration +public final class MockClock: Clock { + public struct Instant: InstantProtocol, Comparable { + public typealias Duration = Swift.Duration - func advanced(by duration: Self.Duration) -> Self { + public func advanced(by duration: Self.Duration) -> Self { .init(self.base + duration) } - func duration(to other: Self) -> Self.Duration { + public func duration(to other: Self) -> Self.Duration { self.base - other.base } private var base: Swift.Duration - init(_ base: Duration) { + public init(_ base: Duration) { self.base = base } - static func < (lhs: Self, rhs: Self) -> Bool { + public static func < (lhs: Self, rhs: Self) -> Bool { lhs.base < rhs.base } - static func == (lhs: Self, rhs: Self) -> Bool { + public static func == (lhs: Self, rhs: Self) -> Bool { lhs.base == rhs.base } } @@ -58,16 +59,18 @@ final class MockClock: Clock { var continuation: CheckedContinuation } - typealias Duration = Swift.Duration + public typealias Duration = Swift.Duration - var minimumResolution: Duration { .nanoseconds(1) } + public var minimumResolution: Duration { .nanoseconds(1) } - var now: Instant { self.stateBox.withLockedValue { $0.now } } + public var now: Instant { self.stateBox.withLockedValue { $0.now } } private let stateBox = NIOLockedValueBox(State()) private let waiterIDGenerator = ManagedAtomic(0) - func sleep(until deadline: Instant, tolerance: Duration?) async throws { + public init() {} + + public func sleep(until deadline: Instant, tolerance: Duration?) async throws { let waiterID = self.waiterIDGenerator.loadThenWrappingIncrement(ordering: .relaxed) return try await withTaskCancellationHandler { @@ -131,7 +134,7 @@ final class MockClock: Clock { } @discardableResult - func nextTimerScheduled() async -> Instant { + public func nextTimerScheduled() async -> Instant { await withCheckedContinuation { (continuation: CheckedContinuation) in let instant = self.stateBox.withLockedValue { state -> Instant? in if let scheduled = state.nextDeadlines.popFirst() { @@ -149,7 +152,7 @@ final class MockClock: Clock { } } - func advance(to deadline: Instant) { + public func advance(to deadline: Instant) { let waiters = self.stateBox.withLockedValue { state -> ArraySlice in precondition(deadline > state.now, "Time can only move forward") state.now = deadline diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift b/Sources/ConnectionPoolTestUtils/MockConnection.swift similarity index 86% rename from Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift rename to Sources/ConnectionPoolTestUtils/MockConnection.swift index f826ea04..db5c3ef7 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift +++ b/Sources/ConnectionPoolTestUtils/MockConnection.swift @@ -1,11 +1,11 @@ +import _ConnectionPoolModule import DequeModule -@testable import _ConnectionPoolModule +import NIOConcurrencyHelpers -// Sendability enforced through the lock -final class MockConnection: PooledConnection, Sendable { - typealias ID = Int +public final class MockConnection: PooledConnection, Sendable { + public typealias ID = Int - let id: ID + public let id: ID private enum State { case running([CheckedContinuation], [@Sendable ((any Error)?) -> ()]) @@ -15,11 +15,11 @@ final class MockConnection: PooledConnection, Sendable { private let lock: NIOLockedValueBox = NIOLockedValueBox(.running([], [])) - init(id: Int) { + public init(id: Int) { self.id = id } - var signalToClose: Void { + public var signalToClose: Void { get async throws { try await withCheckedThrowingContinuation { continuation in let runRightAway = self.lock.withLockedValue { state -> Bool in @@ -41,7 +41,7 @@ final class MockConnection: PooledConnection, Sendable { } } - func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) { + public func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) { let enqueued = self.lock.withLockedValue { state -> Bool in switch state { case .closed: @@ -64,7 +64,7 @@ final class MockConnection: PooledConnection, Sendable { } } - func close() { + public func close() { let continuations = self.lock.withLockedValue { state -> [CheckedContinuation] in switch state { case .running(let continuations, let callbacks): @@ -81,7 +81,7 @@ final class MockConnection: PooledConnection, Sendable { } } - func closeIfClosing() { + public func closeIfClosing() { let callbacks = self.lock.withLockedValue { state -> [@Sendable ((any Error)?) -> ()] in switch state { case .running, .closed: @@ -100,7 +100,7 @@ final class MockConnection: PooledConnection, Sendable { } extension MockConnection: CustomStringConvertible { - var description: String { + public var description: String { let state = self.lock.withLockedValue { $0 } return "MockConnection(id: \(self.id), state: \(state))" } diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift b/Sources/ConnectionPoolTestUtils/MockConnectionFactory.swift similarity index 71% rename from Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift rename to Sources/ConnectionPoolTestUtils/MockConnectionFactory.swift index 1c9bfff8..936b47cc 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift +++ b/Sources/ConnectionPoolTestUtils/MockConnectionFactory.swift @@ -1,14 +1,15 @@ -@testable import _ConnectionPoolModule +import _ConnectionPoolModule import DequeModule +import NIOConcurrencyHelpers @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class MockConnectionFactory: Sendable where Clock.Duration == Duration { - typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator - typealias Request = ConnectionRequest - typealias KeepAliveBehavior = MockPingPongBehavior - typealias MetricsDelegate = NoOpConnectionPoolMetrics - typealias ConnectionID = Int - typealias Connection = MockConnection +public final class MockConnectionFactory: Sendable where Clock.Duration == Duration { + public typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator + public typealias Request = ConnectionRequest + public typealias KeepAliveBehavior = MockPingPongBehavior + public typealias MetricsDelegate = NoOpConnectionPoolMetrics + public typealias ConnectionID = Int + public typealias Connection = MockConnection let stateBox = NIOLockedValueBox(State()) @@ -20,18 +21,33 @@ final class MockConnectionFactory: Sendable where Clo var runningConnections = [ConnectionID: Connection]() } - var pendingConnectionAttemptsCount: Int { + let autoMaxStreams: UInt16? + + public init(autoMaxStreams: UInt16? = nil) { + self.autoMaxStreams = autoMaxStreams + } + + public var pendingConnectionAttemptsCount: Int { self.stateBox.withLockedValue { $0.attempts.count } } - var runningConnections: [Connection] { + public var runningConnections: [Connection] { self.stateBox.withLockedValue { Array($0.runningConnections.values) } } - func makeConnection( + public func makeConnection( id: Int, for pool: ConnectionPool, NoOpConnectionPoolMetrics, Clock> ) async throws -> ConnectionAndMetadata { + if let autoMaxStreams = self.autoMaxStreams { + let connection = MockConnection(id: id) + Task { + try? await connection.signalToClose + connection.closeIfClosing() + } + return .init(connection: connection, maximalStreamsOnConnection: autoMaxStreams) + } + // we currently don't support cancellation when creating a connection let result = try await withCheckedThrowingContinuation { (checkedContinuation: CheckedContinuation<(MockConnection, UInt16), any Error>) in let waiter = self.stateBox.withLockedValue { state -> (CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>)? in @@ -52,7 +68,7 @@ final class MockConnectionFactory: Sendable where Clo } @discardableResult - func nextConnectAttempt(_ closure: (ConnectionID) async throws -> UInt16) async rethrows -> Connection { + public func nextConnectAttempt(_ closure: (ConnectionID) async throws -> UInt16) async rethrows -> Connection { let (connectionID, continuation) = await withCheckedContinuation { (continuation: CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>) in let attempt = self.stateBox.withLockedValue { state -> (ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)? in if let attempt = state.attempts.popFirst() { diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift b/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift similarity index 82% rename from Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift rename to Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift index 637f096c..de1a7275 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift +++ b/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift @@ -1,9 +1,10 @@ -@testable import _ConnectionPoolModule +import _ConnectionPoolModule import DequeModule +import NIOConcurrencyHelpers @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class MockPingPongBehavior: ConnectionKeepAliveBehavior { - let keepAliveFrequency: Duration? +public final class MockPingPongBehavior: ConnectionKeepAliveBehavior { + public let keepAliveFrequency: Duration? let stateBox = NIOLockedValueBox(State()) @@ -13,11 +14,11 @@ final class MockPingPongBehavior: ConnectionKeepAl var waiter = Deque), Never>>() } - init(keepAliveFrequency: Duration?, connectionType: Connection.Type) { + public init(keepAliveFrequency: Duration?, connectionType: Connection.Type) { self.keepAliveFrequency = keepAliveFrequency } - func runKeepAlive(for connection: Connection) async throws { + public func runKeepAlive(for connection: Connection) async throws { precondition(self.keepAliveFrequency != nil) // we currently don't support cancellation when creating a connection @@ -40,7 +41,7 @@ final class MockPingPongBehavior: ConnectionKeepAl } @discardableResult - func nextKeepAlive(_ closure: (Connection) async throws -> Bool) async rethrows -> Connection { + public func nextKeepAlive(_ closure: (Connection) async throws -> Bool) async rethrows -> Connection { let (connection, continuation) = await withCheckedContinuation { (continuation: CheckedContinuation<(Connection, CheckedContinuation), Never>) in let run = self.stateBox.withLockedValue { state -> (Connection, CheckedContinuation)? in if let run = state.runs.popFirst() { diff --git a/Sources/ConnectionPoolTestUtils/MockRequest.swift b/Sources/ConnectionPoolTestUtils/MockRequest.swift new file mode 100644 index 00000000..3dd8b0fb --- /dev/null +++ b/Sources/ConnectionPoolTestUtils/MockRequest.swift @@ -0,0 +1,27 @@ +import _ConnectionPoolModule + +public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { + public struct ID: Hashable, Sendable { + var objectID: ObjectIdentifier + + init(_ request: MockRequest) { + self.objectID = ObjectIdentifier(request) + } + } + + public init(connectionType: Connection.Type = Connection.self) {} + + public var id: ID { ID(self) } + + public static func ==(lhs: MockRequest, rhs: MockRequest) -> Bool { + lhs.id == rhs.id + } + + public func hash(into hasher: inout Hasher) { + hasher.combine(self.id) + } + + public func complete(with: Result, ConnectionPoolError>) { + + } +} diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift index dd0f5404..b260723a 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift @@ -192,9 +192,22 @@ extension PostgresConnection { /// - Parameters: /// - channel: The `NIOCore/Channel` to use. The channel must already be active and connected to an /// endpoint (i.e. `NIOCore/Channel/isActive` must be `true`). - /// - tls: The TLS mode to use. Defaults to ``TLS-swift.struct/disable``. + /// - tls: The TLS mode to use. + public init(establishedChannel channel: Channel, tls: PostgresConnection.Configuration.TLS, username: String, password: String?, database: String?) { + self.init(endpointInfo: .configureChannel(channel), tls: tls, username: username, password: password, database: database) + } + + /// Create a configuration for establishing a connection to a Postgres server over a preestablished + /// `NIOCore/Channel`. + /// + /// This is provided for calling code which wants to manage the underlying connection transport on its + /// own, such as when tunneling a connection through SSH. + /// + /// - Parameters: + /// - channel: The `NIOCore/Channel` to use. The channel must already be active and connected to an + /// endpoint (i.e. `NIOCore/Channel/isActive` must be `true`). public init(establishedChannel channel: Channel, username: String, password: String?, database: String?) { - self.init(endpointInfo: .configureChannel(channel), tls: .disable, username: username, password: password, database: database) + self.init(establishedChannel: channel, tls: .disable, username: username, password: password, database: database) } // MARK: - Implementation details diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index a6efcfdf..fc48fa31 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -60,18 +60,18 @@ public final class PostgresConnection: @unchecked Sendable { func start(configuration: InternalConfiguration) -> EventLoopFuture { // 1. configure handlers - let configureSSLCallback: ((Channel) throws -> ())? + let configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> ())? switch configuration.tls.base { case .prefer(let context), .require(let context): - configureSSLCallback = { channel in + configureSSLCallback = { channel, postgresChannelHandler in channel.eventLoop.assertInEventLoop() let sslHandler = try NIOSSLClientHandler( context: context, serverHostname: configuration.serverNameForTLS ) - try channel.pipeline.syncOperations.addHandler(sslHandler, position: .first) + try channel.pipeline.syncOperations.addHandler(sslHandler, position: .before(postgresChannelHandler)) } case .disable: configureSSLCallback = nil @@ -222,7 +222,7 @@ public final class PostgresConnection: @unchecked Sendable { promise: promise ) - self.write(.extendedQuery(context), cascadingFailureTo: promise) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) return promise.futureResult } @@ -239,8 +239,7 @@ public final class PostgresConnection: @unchecked Sendable { promise: promise ) - self.write(.extendedQuery(context), cascadingFailureTo: promise) - + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) return promise.futureResult.map { rowDescription in PSQLPreparedStatement(name: name, query: query, connection: self, rowDescription: rowDescription) } @@ -256,8 +255,7 @@ public final class PostgresConnection: @unchecked Sendable { logger: logger, promise: promise) - self.write(.extendedQuery(context), cascadingFailureTo: promise) - + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) return promise.futureResult } @@ -265,8 +263,7 @@ public final class PostgresConnection: @unchecked Sendable { let promise = self.channel.eventLoop.makePromise(of: Void.self) let context = CloseCommandContext(target: target, logger: logger, promise: promise) - self.write(.closeCommand(context), cascadingFailureTo: promise) - + self.channel.write(HandlerTask.closeCommand(context), promise: nil) return promise.futureResult } @@ -429,7 +426,7 @@ extension PostgresConnection { promise: promise ) - self.write(.extendedQuery(context), cascadingFailureTo: promise) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) do { return try await promise.futureResult.map({ $0.asyncSequence() }).get() @@ -458,11 +455,7 @@ extension PostgresConnection { let task = HandlerTask.startListening(listener) - let writePromise = self.channel.eventLoop.makePromise(of: Void.self) - self.channel.write(task, promise: writePromise) - writePromise.futureResult.whenFailure { error in - listener.failed(error) - } + self.channel.write(task, promise: nil) } } onCancel: { let task = HandlerTask.cancelListening(channel, id) @@ -487,9 +480,7 @@ extension PostgresConnection { logger: logger, promise: promise )) - - self.write(task, cascadingFailureTo: promise) - + self.channel.write(task, promise: nil) do { return try await promise.futureResult .map { $0.asyncSequence() } @@ -524,9 +515,7 @@ extension PostgresConnection { logger: logger, promise: promise )) - - self.write(task, cascadingFailureTo: promise) - + self.channel.write(task, promise: nil) do { return try await promise.futureResult .map { $0.commandTag } @@ -542,10 +531,55 @@ extension PostgresConnection { } } - private func write(_ task: HandlerTask, cascadingFailureTo promise: EventLoopPromise) { - let writePromise = self.channel.eventLoop.makePromise(of: Void.self) - self.channel.write(task, promise: writePromise) - writePromise.futureResult.cascadeFailure(to: promise) + /// Puts the connection into an open transaction state, for the provided `closure`'s lifetime. + /// + /// The function starts a transaction by running a `BEGIN` query on the connection against the database. It then + /// lends the connection to the user provided closure. The user can then modify the database as they wish. If the user + /// provided closure returns successfully, the function will attempt to commit the changes by running a `COMMIT` + /// query against the database. If the user provided closure throws an error, the function will attempt to rollback the + /// changes made within the closure. + /// + /// - Parameters: + /// - logger: The `Logger` to log into for the transaction. + /// - file: The file, the transaction was started in. Used for better error reporting. + /// - line: The line, the transaction was started in. Used for better error reporting. + /// - closure: The user provided code to modify the database. Use the provided connection to run queries. + /// The connection must stay in the transaction mode. Otherwise this method will throw! + /// - Returns: The closure's return value. + public func withTransaction( + logger: Logger, + file: String = #file, + line: Int = #line, + isolation: isolated (any Actor)? = #isolation, + _ process: (PostgresConnection) async throws -> sending Result + ) async throws -> sending Result { + do { + try await self.query("BEGIN;", logger: logger) + } catch { + throw PostgresTransactionError(file: file, line: line, beginError: error) + } + + var closureHasFinished: Bool = false + do { + let value = try await process(self) + closureHasFinished = true + try await self.query("COMMIT;", logger: logger) + return value + } catch { + var transactionError = PostgresTransactionError(file: file, line: line) + if !closureHasFinished { + transactionError.closureError = error + do { + try await self.query("ROLLBACK;", logger: logger) + } catch { + transactionError.rollbackError = error + } + } else { + transactionError.commitError = error + } + + throw transactionError + } } } @@ -691,7 +725,7 @@ internal enum PostgresCommands: PostgresRequest { /// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support. public final class PostgresListenContext: Sendable { - let promise: EventLoopPromise + private let promise: EventLoopPromise var future: EventLoopFuture { self.promise.futureResult @@ -730,7 +764,8 @@ extension PostgresConnection { closure: notificationHandler ) - self.write(.startListening(listener), cascadingFailureTo: listenContext.promise) + let task = HandlerTask.startListening(listener) + self.channel.write(task, promise: nil) listenContext.future.whenComplete { _ in let task = HandlerTask.cancelListening(channel, id) @@ -777,4 +812,3 @@ extension PostgresConnection { #endif } } - diff --git a/Sources/PostgresNIO/Docs.docc/theme-settings.json b/Sources/PostgresNIO/Docs.docc/theme-settings.json index dda76197..38914a04 100644 --- a/Sources/PostgresNIO/Docs.docc/theme-settings.json +++ b/Sources/PostgresNIO/Docs.docc/theme-settings.json @@ -1,16 +1,19 @@ { "theme": { - "aside": { "border-radius": "6px", "border-style": "double", "border-width": "3px" }, + "aside": { "border-radius": "16px", "border-style": "double", "border-width": "3px" }, "border-radius": "0", "button": { "border-radius": "16px", "border-width": "1px", "border-style": "solid" }, "code": { "border-radius": "16px", "border-width": "1px", "border-style": "solid" }, "color": { + "fill": { "dark": "#000", "light": "#fff" }, "psqlnio": "#336791", "documentation-intro-fill": "radial-gradient(circle at top, var(--color-psqlnio) 30%, #000 100%)", "documentation-intro-accent": "var(--color-psqlnio)", + "documentation-intro-eyebrow": "white", + "documentation-intro-figure": "white", + "documentation-intro-title": "white", "logo-base": { "dark": "#fff", "light": "#000" }, - "logo-shape": { "dark": "#000", "light": "#fff" }, - "fill": { "dark": "#000", "light": "#fff" } + "logo-shape": { "dark": "#000", "light": "#fff" } }, "icons": { "technology": "/postgresnio/images/vapor-postgresnio-logo.svg" } }, diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 9d264bcc..8560b948 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -752,6 +752,12 @@ struct ConnectionStateMachine { return self.modify(with: action) } + mutating func copyInResponseReceived( + _ copyInResponse: PostgresBackendMessage.CopyInResponse + ) -> ConnectionAction { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + } + mutating func emptyQueryResponseReceived() -> ConnectionAction { guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.emptyQueryResponse)) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 78f0d202..5708b6b9 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -10,7 +10,8 @@ struct ExtendedQueryStateMachine { case parameterDescriptionReceived(ExtendedQueryContext) case rowDescriptionReceived(ExtendedQueryContext, [RowDescription.Column]) case noDataMessageReceived(ExtendedQueryContext) - + case emptyQueryResponseReceived + /// A state that is used if a noData message was received before. If a row description was received `bufferingRows` is /// used after receiving a `bindComplete` message case bindCompleteReceived(ExtendedQueryContext) @@ -90,7 +91,7 @@ struct ExtendedQueryStateMachine { mutating func cancel() -> Action { switch self.state { case .initialized: - preconditionFailure("Start must be called immediatly after the query was created") + preconditionFailure("Start must be called immediately after the query was created") case .messagesSent(let queryContext), .parseCompleteReceived(let queryContext), @@ -122,7 +123,7 @@ struct ExtendedQueryStateMachine { return .forwardStreamError(.queryCancelled, read: true) } - case .commandComplete, .error, .drain: + case .commandComplete, .emptyQueryResponseReceived, .error, .drain: // the stream has already finished. return .wait @@ -229,6 +230,7 @@ struct ExtendedQueryStateMachine { .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, + .emptyQueryResponseReceived, .bindCompleteReceived, .streaming, .drain, @@ -268,6 +270,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .bindCompleteReceived, .commandComplete, @@ -285,7 +288,7 @@ struct ExtendedQueryStateMachine { case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return self.avoidingStateMachineCoW { state -> Action in state = .commandComplete(commandTag: commandTag) - let result = QueryResult(value: .noRows(commandTag), logger: context.logger) + let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger) return .succeedQuery(eventLoopPromise, with: result) } @@ -309,6 +312,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .commandComplete, .error: @@ -318,8 +322,29 @@ struct ExtendedQueryStateMachine { } } + mutating func copyInResponseReceived( + _ copyInResponse: PostgresBackendMessage.CopyInResponse + ) -> Action { + return self.setAndFireError(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + } + mutating func emptyQueryResponseReceived() -> Action { - preconditionFailure("Unimplemented") + guard case .bindCompleteReceived(let queryContext) = self.state else { + return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) + } + + switch queryContext.query { + case .unnamed(_, let eventLoopPromise), + .executeStatement(_, let eventLoopPromise): + return self.avoidingStateMachineCoW { state -> Action in + state = .emptyQueryResponseReceived + let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger) + return .succeedQuery(eventLoopPromise, with: result) + } + + case .prepareStatement(_, _, _, _): + return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) + } } mutating func errorReceived(_ errorMessage: PostgresBackendMessage.ErrorResponse) -> Action { @@ -336,7 +361,7 @@ struct ExtendedQueryStateMachine { return self.setAndFireError(error) case .streaming, .drain: return self.setAndFireError(error) - case .commandComplete: + case .commandComplete, .emptyQueryResponseReceived: return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) case .error: preconditionFailure(""" @@ -382,6 +407,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .bindCompleteReceived: preconditionFailure("Requested to consume next row without anything going on.") @@ -405,6 +431,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .bindCompleteReceived: return .wait @@ -449,6 +476,7 @@ struct ExtendedQueryStateMachine { } case .initialized, .commandComplete, + .emptyQueryResponseReceived, .drain, .error: // we already have the complete stream received, now we are waiting for a @@ -495,7 +523,7 @@ struct ExtendedQueryStateMachine { return .forwardStreamError(error, read: true) } - case .commandComplete, .error: + case .commandComplete, .emptyQueryResponseReceived, .error: preconditionFailure(""" This state must not be reached. If the query `.isComplete`, the ConnectionStateMachine must not send any further events to the substate machine. @@ -507,7 +535,7 @@ struct ExtendedQueryStateMachine { var isComplete: Bool { switch self.state { - case .commandComplete, .error: + case .commandComplete, .emptyQueryResponseReceived, .error: return true case .noDataMessageReceived(let context), .rowDescriptionReceived(let context, _): diff --git a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift index 41091ab3..7e8376a7 100644 --- a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift @@ -29,6 +29,12 @@ extension String: PostgresDecodable { context: PostgresDecodingContext ) throws { switch (format, type) { + case (.binary, .jsonb): + // Discard the version byte + guard let version = buffer.readInteger(as: UInt8.self), version == 1 else { + throw PostgresDecodingError.Code.failure + } + self = buffer.readString(length: buffer.readableBytes)! case (_, .varchar), (_, .bpchar), (_, .text), @@ -36,13 +42,20 @@ extension String: PostgresDecodable { // we can force unwrap here, since this method only fails if there are not enough // bytes available. self = buffer.readString(length: buffer.readableBytes)! + case (_, .uuid): guard let uuid = try? UUID(from: &buffer, type: .uuid, format: format, context: context) else { throw PostgresDecodingError.Code.failure } self = uuid.uuidString + default: - throw PostgresDecodingError.Code.typeMismatch + // We should eagerly try to convert any datatype into a String. For example the oid + // for ltree isn't static. For this reason we should just try to convert anything. + guard let string = buffer.readString(length: buffer.readableBytes) else { + throw PostgresDecodingError.Code.typeMismatch + } + self = string } } } diff --git a/Sources/PostgresNIO/New/Messages/CopyInMessage.swift b/Sources/PostgresNIO/New/Messages/CopyInMessage.swift new file mode 100644 index 00000000..46dec648 --- /dev/null +++ b/Sources/PostgresNIO/New/Messages/CopyInMessage.swift @@ -0,0 +1,44 @@ +extension PostgresBackendMessage { + struct CopyInResponse: Hashable { + enum Format: Int8 { + case textual = 0 + case binary = 1 + } + + let format: Format + let columnFormats: [Format] + + static func decode(from buffer: inout ByteBuffer) throws -> Self { + guard let rawFormat = buffer.readInteger(endianness: .big, as: Int8.self) else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(1, actual: buffer.readableBytes) + } + guard let format = Format(rawValue: rawFormat) else { + throw PSQLPartialDecodingError.unexpectedValue(value: rawFormat) + } + + guard let numColumns = buffer.readInteger(endianness: .big, as: Int16.self) else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(2, actual: buffer.readableBytes) + } + var columnFormatCodes: [Format] = [] + columnFormatCodes.reserveCapacity(Int(numColumns)) + + for _ in 0.. (stream: AsyncThrowingStream, continuation: AsyncThrowingStream.Continuation) where Failure == Error { - var continuation: AsyncThrowingStream.Continuation! - let stream = AsyncThrowingStream(bufferingPolicy: limit) { continuation = $0 } - return (stream: stream, continuation: continuation!) - } - } - #endif diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index b7f2d4fb..ee925d0e 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -3,7 +3,7 @@ import Logging struct QueryResult { enum Value: Equatable { - case noRows(String) + case noRows(PSQLRowStream.StatementSummary) case rowDescription([RowDescription.Column]) } @@ -16,25 +16,30 @@ struct QueryResult { final class PSQLRowStream: @unchecked Sendable { private typealias AsyncSequenceSource = NIOThrowingAsyncSequenceProducer.Source + enum StatementSummary: Equatable { + case tag(String) + case emptyResponse + } + enum Source { case stream([RowDescription.Column], PSQLRowsDataSource) - case noRows(Result) + case noRows(Result) } let eventLoop: EventLoop let logger: Logger - + private enum BufferState { case streaming(buffer: CircularBuffer, dataSource: PSQLRowsDataSource) - case finished(buffer: CircularBuffer, commandTag: String) + case finished(buffer: CircularBuffer, summary: StatementSummary) case failure(Error) } - + private enum DownstreamState { case waitingForConsumer(BufferState) case iteratingRows(onRow: (PostgresRow) throws -> (), EventLoopPromise, PSQLRowsDataSource) case waitingForAll([PostgresRow], EventLoopPromise<[PostgresRow]>, PSQLRowsDataSource) - case consumed(Result) + case consumed(Result) case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource, onFinish: @Sendable () -> ()) } @@ -52,9 +57,9 @@ final class PSQLRowStream: @unchecked Sendable { case .stream(let rowDescription, let dataSource): self.rowDescription = rowDescription bufferState = .streaming(buffer: .init(), dataSource: dataSource) - case .noRows(.success(let commandTag)): + case .noRows(.success(let summary)): self.rowDescription = [] - bufferState = .finished(buffer: .init(), commandTag: commandTag) + bufferState = .finished(buffer: .init(), summary: summary) case .noRows(.failure(let error)): self.rowDescription = [] bufferState = .failure(error) @@ -98,12 +103,12 @@ final class PSQLRowStream: @unchecked Sendable { self.downstreamState = .asyncSequence(source, dataSource, onFinish: onFinish) self.executeActionBasedOnYieldResult(yieldResult, source: dataSource) - case .finished(let buffer, let commandTag): + case .finished(let buffer, let summary): _ = source.yield(contentsOf: buffer) source.finish() onFinish() - self.downstreamState = .consumed(.success(commandTag)) - + self.downstreamState = .consumed(.success(summary)) + case .failure(let error): source.finish(error) self.downstreamState = .consumed(.failure(error)) @@ -190,12 +195,12 @@ final class PSQLRowStream: @unchecked Sendable { dataSource.request(for: self) return promise.futureResult - case .finished(let buffer, let commandTag): + case .finished(let buffer, let summary): let rows = buffer.map { PostgresRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription) } - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(summary)) return self.eventLoop.makeSucceededFuture(rows) case .failure(let error): @@ -247,8 +252,8 @@ final class PSQLRowStream: @unchecked Sendable { } return promise.futureResult - - case .finished(let buffer, let commandTag): + + case .finished(let buffer, let summary): do { for data in buffer { let row = PostgresRow( @@ -259,7 +264,7 @@ final class PSQLRowStream: @unchecked Sendable { try onRow(row) } - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(summary)) return self.eventLoop.makeSucceededVoidFuture() } catch { self.downstreamState = .consumed(.failure(error)) @@ -292,7 +297,7 @@ final class PSQLRowStream: @unchecked Sendable { case .waitingForConsumer(.finished), .waitingForConsumer(.failure): preconditionFailure("How can new rows be received, if an end was already signalled?") - + case .iteratingRows(let onRow, let promise, let dataSource): do { for data in newRows { @@ -347,25 +352,25 @@ final class PSQLRowStream: @unchecked Sendable { private func receiveEnd(_ commandTag: String) { switch self.downstreamState { case .waitingForConsumer(.streaming(buffer: let buffer, _)): - self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, commandTag: commandTag)) - - case .waitingForConsumer(.finished), .waitingForConsumer(.failure): + self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, summary: .tag(commandTag))) + + case .waitingForConsumer(.finished), .waitingForConsumer(.failure), .consumed(.success(.emptyResponse)): preconditionFailure("How can we get another end, if an end was already signalled?") case .iteratingRows(_, let promise, _): - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) promise.succeed(()) case .waitingForAll(let rows, let promise, _): - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) promise.succeed(rows) case .asyncSequence(let source, _, let onFinish): - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) source.finish() onFinish() - case .consumed: + case .consumed(.success(.tag)), .consumed(.failure): break } } @@ -375,7 +380,7 @@ final class PSQLRowStream: @unchecked Sendable { case .waitingForConsumer(.streaming): self.downstreamState = .waitingForConsumer(.failure(error)) - case .waitingForConsumer(.finished), .waitingForConsumer(.failure): + case .waitingForConsumer(.finished), .waitingForConsumer(.failure), .consumed(.success(.emptyResponse)): preconditionFailure("How can we get another end, if an end was already signalled?") case .iteratingRows(_, let promise, _): @@ -391,7 +396,7 @@ final class PSQLRowStream: @unchecked Sendable { consumer.finish(error) onFinish() - case .consumed: + case .consumed(.success(.tag)), .consumed(.failure): break } } @@ -413,10 +418,15 @@ final class PSQLRowStream: @unchecked Sendable { } var commandTag: String { - guard case .consumed(.success(let commandTag)) = self.downstreamState else { + guard case .consumed(.success(let consumed)) = self.downstreamState else { preconditionFailure("commandTag may only be called if all rows have been consumed") } - return commandTag + switch consumed { + case .tag(let tag): + return tag + case .emptyResponse: + return "" + } } } diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 363f9394..6106fd21 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -1,7 +1,7 @@ import Logging import NIOCore -enum HandlerTask { +enum HandlerTask: Sendable { case extendedQuery(ExtendedQueryContext) case closeCommand(CloseCommandContext) case startListening(NotificationListener) @@ -31,7 +31,7 @@ enum PSQLTask { } } -final class ExtendedQueryContext { +final class ExtendedQueryContext: Sendable { enum Query { case unnamed(PostgresQuery, EventLoopPromise) case executeStatement(PSQLExecuteStatement, EventLoopPromise) @@ -100,14 +100,15 @@ final class PreparedStatementContext: Sendable { } } -final class CloseCommandContext { +final class CloseCommandContext: Sendable { let target: CloseTarget let logger: Logger let promise: EventLoopPromise - init(target: CloseTarget, - logger: Logger, - promise: EventLoopPromise + init( + target: CloseTarget, + logger: Logger, + promise: EventLoopPromise ) { self.target = target self.logger = logger diff --git a/Sources/PostgresNIO/New/PostgresBackendMessage.swift b/Sources/PostgresNIO/New/PostgresBackendMessage.swift index 792beec3..030d651d 100644 --- a/Sources/PostgresNIO/New/PostgresBackendMessage.swift +++ b/Sources/PostgresNIO/New/PostgresBackendMessage.swift @@ -29,6 +29,7 @@ enum PostgresBackendMessage: Hashable { case bindComplete case closeComplete case commandComplete(String) + case copyInResponse(CopyInResponse) case dataRow(DataRow) case emptyQueryResponse case error(ErrorResponse) @@ -96,6 +97,9 @@ extension PostgresBackendMessage { } return .commandComplete(commandTag) + case .copyInResponse: + return try .copyInResponse(.decode(from: &buffer)) + case .dataRow: return try .dataRow(.decode(from: &buffer)) @@ -131,9 +135,9 @@ extension PostgresBackendMessage { case .rowDescription: return try .rowDescription(.decode(from: &buffer)) - - case .copyData, .copyDone, .copyInResponse, .copyOutResponse, .copyBothResponse, .functionCallResponse, .negotiateProtocolVersion: - preconditionFailure() + + case .copyData, .copyDone, .copyOutResponse, .copyBothResponse, .functionCallResponse, .negotiateProtocolVersion: + throw PSQLPartialDecodingError.unknownMessageKind(messageID) } } } @@ -151,6 +155,8 @@ extension PostgresBackendMessage: CustomDebugStringConvertible { return ".closeComplete" case .commandComplete(let commandTag): return ".commandComplete(\(String(reflecting: commandTag)))" + case .copyInResponse(let copyInResponse): + return ".copyInResponse(\(String(reflecting: copyInResponse)))" case .dataRow(let dataRow): return ".dataRow(\(String(reflecting: dataRow)))" case .emptyQueryResponse: diff --git a/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift b/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift index 6f6be7ec..155c6714 100644 --- a/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift +++ b/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift @@ -189,6 +189,12 @@ struct PSQLPartialDecodingError: Error { description: "Expected the integer to be positive or null, but got \(actual).", file: file, line: line) } + + static func unknownMessageKind(_ messageID: PostgresBackendMessage.ID, file: String = #fileID, line: Int = #line) -> Self { + return PSQLPartialDecodingError( + description: "Unknown message kind: \(messageID)", + file: file, line: line) + } } extension ByteBuffer { diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index a3190aa7..bc256203 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -20,7 +20,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private var decoder: NIOSingleStepByteToMessageProcessor private var encoder: PostgresFrontendMessageEncoder! private let configuration: PostgresConnection.InternalConfiguration - private let configureSSLCallback: ((Channel) throws -> Void)? + private let configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)? private var listenState = ListenStateMachine() private var preparedStatementState = PreparedStatementStateMachine() @@ -29,7 +29,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { configuration: PostgresConnection.InternalConfiguration, eventLoop: EventLoop, logger: Logger, - configureSSLCallback: ((Channel) throws -> Void)? + configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)? ) { self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData) self.eventLoop = eventLoop @@ -46,7 +46,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { eventLoop: EventLoop, state: ConnectionStateMachine = .init(.initialized), logger: Logger = .psqlNoOpLogger, - configureSSLCallback: ((Channel) throws -> Void)? + configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)? ) { self.state = state self.eventLoop = eventLoop @@ -136,6 +136,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { action = self.state.closeCompletedReceived() case .commandComplete(let commandTag): action = self.state.commandCompletedReceived(commandTag) + case .copyInResponse(let copyInResponse): + action = self.state.copyInResponseReceived(copyInResponse) case .dataRow(let dataRow): action = self.state.dataRowReceived(dataRow) case .emptyQueryResponse: @@ -439,7 +441,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { // This method must only be called, if we signalized the StateMachine before that we are // able to setup a SSL connection. do { - try self.configureSSLCallback!(context.channel) + try self.configureSSLCallback!(context.channel, self) let action = self.state.sslHandlerAdded() self.run(action, with: context) } catch { @@ -550,9 +552,9 @@ final class PostgresChannelHandler: ChannelDuplexHandler { ) self.rowStream = rows - case .noRows(let commandTag): + case .noRows(let summary): rows = PSQLRowStream( - source: .noRows(.success(commandTag)), + source: .noRows(.success(summary)), eventLoop: context.channel.eventLoop, logger: result.logger ) @@ -565,8 +567,13 @@ final class PostgresChannelHandler: ChannelDuplexHandler { _ cleanup: ConnectionStateMachine.ConnectionAction.CleanUpContext, context: ChannelHandlerContext ) { - self.logger.debug("Cleaning up and closing connection.", metadata: [.error: "\(cleanup.error)"]) - + // Don't log a misleading error if the client closed the connection. + if cleanup.error.code == .clientClosedConnection { + self.logger.debug("Cleaning up and closing connection.") + } else { + self.logger.debug("Cleaning up and closing connection.", metadata: [.error: "\(cleanup.error)"]) + } + // 1. fail all tasks cleanup.tasks.forEach { task in task.failWithError(cleanup.error) diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift index 97805418..6ca4cc27 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift @@ -167,6 +167,28 @@ struct PostgresFrontendMessageEncoder { self.buffer.writeMultipleIntegers(UInt32(8), Self.sslRequestCode) } + /// Adds the `CopyData` message ID and `dataLength` to the message buffer but not the actual data. + /// + /// The caller of this function is expected to write the encoder's message buffer to the backend after calling this + /// function, followed by sending the actual data to the backend. + mutating func copyDataHeader(dataLength: UInt32) { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .copyData, length: dataLength) + } + + mutating func copyDone() { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .copyDone, length: 0) + } + + mutating func copyFail(message: String) { + self.clearIfNeeded() + var messageBuffer = ByteBuffer() + messageBuffer.writeNullTerminatedString(message) + self.buffer.psqlWriteMultipleIntegers(id: .copyFail, length: UInt32(messageBuffer.readableBytes)) + self.buffer.writeImmutableBuffer(messageBuffer) + } + mutating func sync() { self.clearIfNeeded() self.buffer.psqlWriteMultipleIntegers(id: .sync, length: 0) @@ -197,6 +219,9 @@ struct PostgresFrontendMessageEncoder { private enum FrontendMessageID: UInt8, Hashable, Sendable { case bind = 66 // B case close = 67 // C + case copyData = 100 // d + case copyDone = 99 // c + case copyFail = 102 // f case describe = 68 // D case execute = 69 // E case flush = 72 // H diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index b695dcfe..6449ab29 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -172,6 +172,16 @@ public struct PostgresBindings: Sendable, Hashable { try self.append(value, context: .default) } + @inlinable + public mutating func append(_ value: Optional) throws { + switch value { + case .none: + self.appendNull() + case let .some(value): + try self.append(value) + } + } + @inlinable public mutating func append( _ value: Value, @@ -181,11 +191,34 @@ public struct PostgresBindings: Sendable, Hashable { self.metadata.append(.init(value: value, protected: true)) } + @inlinable + public mutating func append( + _ value: Optional, + context: PostgresEncodingContext + ) throws { + switch value { + case .none: + self.appendNull() + case let .some(value): + try self.append(value, context: context) + } + } + @inlinable public mutating func append(_ value: Value) { self.append(value, context: .default) } + @inlinable + public mutating func append(_ value: Optional) { + switch value { + case .none: + self.appendNull() + case let .some(value): + self.append(value) + } + } + @inlinable public mutating func append( _ value: Value, @@ -195,6 +228,19 @@ public struct PostgresBindings: Sendable, Hashable { self.metadata.append(.init(value: value, protected: true)) } + @inlinable + public mutating func append( + _ value: Optional, + context: PostgresEncodingContext + ) { + switch value { + case .none: + self.appendNull() + case let .some(value): + self.append(value, context: context) + } + } + @inlinable mutating func appendUnprotected( _ value: Value, diff --git a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift deleted file mode 100644 index 71aa04dc..00000000 --- a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift +++ /dev/null @@ -1,1175 +0,0 @@ -/// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrow-multi-decode.sh - -#if compiler(<5.9) -extension PostgresRow { - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0) { - precondition(self.columns.count >= 1) - let columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - let column = columnIterator.next().unsafelyUnwrapped - let swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0).Type, file: String = #fileID, line: Int = #line) throws -> (T0) { - try self.decode(T0.self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1) { - precondition(self.columns.count >= 2) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1) { - try self.decode((T0, T1).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2) { - precondition(self.columns.count >= 3) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2) { - try self.decode((T0, T1, T2).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3) { - precondition(self.columns.count >= 4) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3) { - try self.decode((T0, T1, T2, T3).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { - precondition(self.columns.count >= 5) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { - try self.decode((T0, T1, T2, T3, T4).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { - precondition(self.columns.count >= 6) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { - try self.decode((T0, T1, T2, T3, T4, T5).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { - precondition(self.columns.count >= 7) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { - try self.decode((T0, T1, T2, T3, T4, T5, T6).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { - precondition(self.columns.count >= 8) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { - precondition(self.columns.count >= 9) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { - precondition(self.columns.count >= 10) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { - precondition(self.columns.count >= 11) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { - precondition(self.columns.count >= 12) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 11 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T11.self - let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { - precondition(self.columns.count >= 13) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 11 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T11.self - let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 12 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T12.self - let r12 = try T12._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { - precondition(self.columns.count >= 14) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 11 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T11.self - let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 12 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T12.self - let r12 = try T12._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 13 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T13.self - let r13 = try T13._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { - precondition(self.columns.count >= 15) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 11 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T11.self - let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 12 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T12.self - let r12 = try T12._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 13 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T13.self - let r13 = try T13._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 14 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T14.self - let r14 = try T14._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: .default, file: file, line: line) - } -} -#endif diff --git a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift deleted file mode 100644 index f45357d8..00000000 --- a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift +++ /dev/null @@ -1,215 +0,0 @@ -/// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrowsequence-multi-decode.sh - -#if compiler(<5.9) -extension AsyncSequence where Element == PostgresRow { - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode(T0.self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode(T0.self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: .default, file: file, line: line) - } -} -#endif diff --git a/Sources/PostgresNIO/New/PostgresTransactionError.swift b/Sources/PostgresNIO/New/PostgresTransactionError.swift new file mode 100644 index 00000000..35038446 --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresTransactionError.swift @@ -0,0 +1,21 @@ +/// A wrapper around the errors that can occur during a transaction. +public struct PostgresTransactionError: Error { + + /// The file in which the transaction was started + public var file: String + /// The line in which the transaction was started + public var line: Int + + /// The error thrown when running the `BEGIN` query + public var beginError: Error? + /// The error thrown in the transaction closure + public var closureError: Error? + + /// The error thrown while rolling the transaction back. If the ``closureError`` is set, + /// but the ``rollbackError`` is empty, the rollback was successful. If the ``rollbackError`` + /// is set, the rollback failed. + public var rollbackError: Error? + + /// The error thrown while commiting the transaction. + public var commitError: Error? +} diff --git a/Sources/PostgresNIO/New/VariadicGenerics.swift b/Sources/PostgresNIO/New/VariadicGenerics.swift index 312d36dc..b284c7a2 100644 --- a/Sources/PostgresNIO/New/VariadicGenerics.swift +++ b/Sources/PostgresNIO/New/VariadicGenerics.swift @@ -1,4 +1,4 @@ -#if compiler(>=5.9) + extension PostgresRow { // --- snip TODO: Remove once bug is fixed, that disallows tuples of one @inlinable @@ -116,7 +116,8 @@ extension PostgresRow { extension AsyncSequence where Element == PostgresRow { // --- snip TODO: Remove once bug is fixed, that disallows tuples of one @inlinable - public func decode( + @preconcurrency + public func decode( _: Column.Type, context: PostgresDecodingContext, file: String = #fileID, @@ -128,7 +129,8 @@ extension AsyncSequence where Element == PostgresRow { } @inlinable - public func decode( + @preconcurrency + public func decode( _: Column.Type, file: String = #fileID, line: Int = #line @@ -137,7 +139,8 @@ extension AsyncSequence where Element == PostgresRow { } // --- snap TODO: Remove once bug is fixed, that disallows tuples of one - public func decode( + @preconcurrency + public func decode( _ columnType: (repeat each Column).Type, context: PostgresDecodingContext, file: String = #fileID, @@ -148,7 +151,8 @@ extension AsyncSequence where Element == PostgresRow { } } - public func decode( + @preconcurrency + public func decode( _ columnType: (repeat each Column).Type, file: String = #fileID, line: Int = #line @@ -170,5 +174,3 @@ enum ComputeParameterPackLength { MemoryLayout<(repeat BoolConverter.Bool)>.size / MemoryLayout.stride } } -#endif // compiler(>=5.9) - diff --git a/Sources/PostgresNIO/Pool/ConnectionFactory.swift b/Sources/PostgresNIO/Pool/ConnectionFactory.swift index 77a0c047..31343826 100644 --- a/Sources/PostgresNIO/Pool/ConnectionFactory.swift +++ b/Sources/PostgresNIO/Pool/ConnectionFactory.swift @@ -15,9 +15,9 @@ final class ConnectionFactory: Sendable { struct SSLContextCache: Sendable { enum State { case none - case producing(TLSConfiguration, [CheckedContinuation]) - case cached(TLSConfiguration, NIOSSLContext) - case failed(TLSConfiguration, any Error) + case producing([CheckedContinuation]) + case cached(NIOSSLContext) + case failed(any Error) } var state: State = .none @@ -89,6 +89,7 @@ final class ConnectionFactory: Sendable { connectionConfig.options.connectTimeout = TimeAmount(config.options.connectTimeout) connectionConfig.options.tlsServerName = config.options.tlsServerName connectionConfig.options.requireBackendKeyData = config.options.requireBackendKeyData + connectionConfig.options.additionalStartupParameters = config.options.additionalStartupParameters return connectionConfig } @@ -105,34 +106,17 @@ final class ConnectionFactory: Sendable { let action = self.sslContextBox.withLockedValue { cache -> Action in switch cache.state { case .none: - cache.state = .producing(tlsConfiguration, [continuation]) + cache.state = .producing([continuation]) return .produce - case .cached(let cachedTLSConfiguration, let context): - if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { - return .succeed(context) - } else { - cache.state = .producing(tlsConfiguration, [continuation]) - return .produce - } - - case .failed(let cachedTLSConfiguration, let error): - if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { - return .fail(error) - } else { - cache.state = .producing(tlsConfiguration, [continuation]) - return .produce - } - - case .producing(let cachedTLSConfiguration, var continuations): + case .cached(let context): + return .succeed(context) + case .failed(let error): + return .fail(error) + case .producing(var continuations): continuations.append(continuation) - if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { - cache.state = .producing(cachedTLSConfiguration, continuations) - return .wait - } else { - cache.state = .producing(tlsConfiguration, continuations) - return .produce - } + cache.state = .producing(continuations) + return .wait } } @@ -142,10 +126,7 @@ final class ConnectionFactory: Sendable { case .produce: // TBD: we might want to consider moving this off the concurrent executor - self.reportProduceSSLContextResult( - Result(catching: {try NIOSSLContext(configuration: tlsConfiguration)}), - for: tlsConfiguration - ) + self.reportProduceSSLContextResult(Result(catching: {try NIOSSLContext(configuration: tlsConfiguration)})) case .succeed(let context): continuation.resume(returning: context) @@ -156,7 +137,7 @@ final class ConnectionFactory: Sendable { } } - private func reportProduceSSLContextResult(_ result: Result, for tlsConfiguration: TLSConfiguration) { + private func reportProduceSSLContextResult(_ result: Result) { enum Action { case fail(any Error, [CheckedContinuation]) case succeed(NIOSSLContext, [CheckedContinuation]) @@ -171,19 +152,15 @@ final class ConnectionFactory: Sendable { case .cached, .failed: return .none - case .producing(let cachedTLSConfiguration, let continuations): - if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { - switch result { - case .success(let context): - cache.state = .cached(cachedTLSConfiguration, context) - return .succeed(context, continuations) - - case .failure(let failure): - cache.state = .failed(cachedTLSConfiguration, failure) - return .fail(failure, continuations) - } - } else { - return .none + case .producing(let continuations): + switch result { + case .success(let context): + cache.state = .cached(context) + return .succeed(context, continuations) + + case .failure(let failure): + cache.state = .failed(failure) + return .fail(failure, continuations) } } } diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 0907f1f8..581b5113 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -106,6 +106,10 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { /// If you are not using Amazon RDS Proxy, you should leave this set to `true` (the default). public var requireBackendKeyData: Bool = true + /// Additional parameters to send to the server on startup. The name value pairs are added to the initial + /// startup message that the client sends to the server. + public var additionalStartupParameters: [(String, String)] = [] + /// The minimum number of connections that the client shall keep open at any time, even if there is no /// demand. Default to `0`. /// @@ -289,19 +293,63 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { return ConnectionAndMetadata(connection: connection, maximalStreamsOnConnection: 1) } } - /// Lease a connection for the provided `closure`'s lifetime. /// /// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture /// the provided `PostgresConnection`. /// - Returns: The closure's return value. + @_disfavoredOverload public func withConnection(_ closure: (PostgresConnection) async throws -> Result) async throws -> Result { - let connection = try await self.leaseConnection() + let lease = try await self.leaseConnection() + + defer { lease.release() } + + return try await closure(lease.connection) + } + + /// Lease a connection for the provided `closure`'s lifetime. + /// + /// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture + /// the provided `PostgresConnection`. + /// - Returns: The closure's return value. + public func withConnection( + isolation: isolated (any Actor)? = #isolation, + _ closure: (PostgresConnection) async throws -> sending Result + ) async throws -> sending Result { + let lease = try await self.leaseConnection() - defer { self.pool.releaseConnection(connection) } + defer { lease.release() } - return try await closure(connection) + return try await closure(lease.connection) + } + + /// Lease a connection, which is in an open transaction state, for the provided `closure`'s lifetime. + /// + /// The function leases a connection from the underlying connection pool and starts a transaction by running a `BEGIN` + /// query on the leased connection against the database. It then lends the connection to the user provided closure. + /// The user can then modify the database as they wish. If the user provided closure returns successfully, the function + /// will attempt to commit the changes by running a `COMMIT` query against the database. If the user provided closure + /// throws an error, the function will attempt to rollback the changes made within the closure. + /// + /// - Parameters: + /// - logger: The `Logger` to log into for the transaction. + /// - file: The file, the transaction was started in. Used for better error reporting. + /// - line: The line, the transaction was started in. Used for better error reporting. + /// - closure: The user provided code to modify the database. Use the provided connection to run queries. + /// The connection must stay in the transaction mode. Otherwise this method will throw! + /// - Returns: The closure's return value. + public func withTransaction( + logger: Logger, + file: String = #file, + line: Int = #line, + isolation: isolated (any Actor)? = #isolation, + _ closure: (PostgresConnection) async throws -> sending Result + ) async throws -> sending Result { + // for 6.0 to compile we need to explicitly forward the isolation. + try await self.withConnection(isolation: isolation) { connection in + try await connection.withTransaction(logger: logger, file: file, line: line, isolation: isolation, closure) + } } /// Run a query on the Postgres server the client is connected to. @@ -326,7 +374,8 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line) } - let connection = try await self.leaseConnection() + let lease = try await self.leaseConnection() + let connection = lease.connection var logger = logger logger[postgresMetadataKey: .connectionID] = "\(connection.id)" @@ -341,12 +390,12 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { connection.channel.write(HandlerTask.extendedQuery(context), promise: nil) promise.futureResult.whenFailure { _ in - self.pool.releaseConnection(connection) + lease.release() } return try await promise.futureResult.map { $0.asyncSequence(onFinish: { - self.pool.releaseConnection(connection) + lease.release() }) }.get() } catch var error as PSQLError { @@ -368,7 +417,8 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { let logger = logger ?? Self.loggingDisabled do { - let connection = try await self.leaseConnection() + let lease = try await self.leaseConnection() + let connection = lease.connection let promise = connection.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( @@ -382,11 +432,11 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { connection.channel.write(task, promise: nil) promise.futureResult.whenFailure { _ in - self.pool.releaseConnection(connection) + lease.release() } return try await promise.futureResult - .map { $0.asyncSequence(onFinish: { self.pool.releaseConnection(connection) }) } + .map { $0.asyncSequence(onFinish: { lease.release() }) } .get() .map { try preparedStatement.decodeRow($0) } } catch var error as PSQLError { @@ -426,7 +476,7 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { // MARK: - Private Methods - - private func leaseConnection() async throws -> PostgresConnection { + private func leaseConnection() async throws -> ConnectionLease { if !self.runningAtomic.load(ordering: .relaxed) { self.backgroundLogger.warning("Trying to lease connection from `PostgresClient`, but `PostgresClient.run()` hasn't been called yet.") } diff --git a/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift b/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift index aa8215db..62fa326a 100644 --- a/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift +++ b/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift @@ -19,7 +19,7 @@ final class PostgresClientMetrics: ConnectionPoolObservabilityDelegate { /// A connection attempt failed with the given error. After some period of /// time ``startedConnecting(id:)`` may be called again. func connectFailed(id: ConnectionID, error: Error) { - self.logger.debug("Connection creation failed", metadata: [ + self.logger.info("Connection creation failed", metadata: [ .connectionID: "\(id)", .error: "\(String(reflecting: error))" ]) diff --git a/Sources/PostgresNIO/PostgresDatabase+Query.swift b/Sources/PostgresNIO/PostgresDatabase+Query.swift index 01a7e61f..8de93814 100644 --- a/Sources/PostgresNIO/PostgresDatabase+Query.swift +++ b/Sources/PostgresNIO/PostgresDatabase+Query.swift @@ -40,7 +40,7 @@ extension PostgresDatabase { } } -public struct PostgresQueryResult { +public struct PostgresQueryResult: Sendable { public let metadata: PostgresQueryMetadata public let rows: [PostgresRow] } @@ -73,10 +73,7 @@ public struct PostgresQueryMetadata: Sendable { init?(string: String) { let parts = string.split(separator: " ") - guard parts.count >= 1 else { - return nil - } - switch parts[0] { + switch parts.first { case "INSERT": // INSERT oid rows guard parts.count == 3 else { diff --git a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift index 2a717b6b..53e9c3f7 100644 --- a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift +++ b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift @@ -1,4 +1,5 @@ import Crypto +import _CryptoExtras import Foundation extension UInt8 { @@ -292,7 +293,7 @@ internal struct SHA256: SASLAuthenticationMechanism { /// authenticating user. If the closure throws, authentication /// immediately fails with the thrown error. internal init(username: String, password: @escaping () throws -> String) { - self._impl = .init(username: username, passwordGrabber: { _ in try (Array(password().data(using: .utf8)!), []) }, bindingInfo: .unsupported) + self._impl = .init(username: username, passwordGrabber: { _ in try (Array(password().utf8), []) }, bindingInfo: .unsupported) } /// Set up a server-side `SCRAM-SHA-256` authentication. @@ -338,7 +339,7 @@ internal struct SHA256_PLUS: SASLAuthenticationMechanism { /// - channelBindingData: The appropriate data associated with the RFC5056 /// channel binding specified. internal init(username: String, password: @escaping () throws -> String, channelBindingName: String, channelBindingData: [UInt8]) { - self._impl = .init(username: username, passwordGrabber: { _ in try (Array(password().data(using: .utf8)!), []) }, bindingInfo: .bind(channelBindingName, channelBindingData)) + self._impl = .init(username: username, passwordGrabber: { _ in try (Array(password().utf8), []) }, bindingInfo: .bind(channelBindingName, channelBindingData)) } /// Set up a server-side `SCRAM-SHA-256` authentication. @@ -466,8 +467,8 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { // TODO: Perform `Normalize(password)`, aka the SASLprep profile (RFC4013) of stringprep (RFC3454) // Calculate `AuthMessage`, `ClientSignature`, and `ClientProof` - let saltedPassword = Hi(string: password, salt: serverSalt, iterations: serverIterations) - let clientKey = HMAC.authenticationCode(for: "Client Key".data(using: .utf8)!, using: .init(data: saltedPassword)) + let saltedPassword = try Hi(string: password, salt: serverSalt, iterations: serverIterations) + let clientKey = HMAC.authenticationCode(for: Data("Client Key".utf8), using: saltedPassword) let storedKey = SHA256.hash(data: Data(clientKey)) var authMessage = firstMessageBare; authMessage.append(.comma); authMessage.append(contentsOf: message); authMessage.append(.comma); authMessage.append(contentsOf: clientFinalNoProof) let clientSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: storedKey)) @@ -485,9 +486,11 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { var clientFinalMessage = clientFinalNoProof; clientFinalMessage.append(.comma) guard let proofPart = SCRAMMessageParser.serialize([.p(Array(clientProof))]) else { throw SASLAuthenticationError.genericAuthenticationFailure } clientFinalMessage.append(contentsOf: proofPart) - + + let saltedPasswordBytes = saltedPassword.withUnsafeBytes { [UInt8]($0) } + // Save state and send - self.state = .clientSentFinalMessage(saltedPassword: saltedPassword, authMessage: authMessage) + self.state = .clientSentFinalMessage(saltedPassword: saltedPasswordBytes, authMessage: authMessage) return .continue(response: clientFinalMessage) } @@ -501,7 +504,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { switch incomingAttributes.first { case .v(let verifier): // Verify server signature - let serverKey = HMAC.authenticationCode(for: "Server Key".data(using: .utf8)!, using: .init(data: saltedPassword)) + let serverKey = HMAC.authenticationCode(for: Data("Server Key".utf8), using: .init(data: saltedPassword)) let serverSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: serverKey)) guard Array(serverSignature) == verifier else { @@ -585,7 +588,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { guard nonce == repeatNonce else { throw SASLAuthenticationError.genericAuthenticationFailure } // Compute client signature - let clientKey = HMAC.authenticationCode(for: "Client Key".data(using: .utf8)!, using: .init(data: saltedPassword)) + let clientKey = HMAC.authenticationCode(for: Data("Client Key".utf8), using: .init(data: saltedPassword)) let storedKey = SHA256.hash(data: Data(clientKey)) var authMessage = clientBareFirstMessage; authMessage.append(.comma); authMessage.append(contentsOf: serverFirstMessage); authMessage.append(.comma); authMessage.append(contentsOf: message.dropLast(proof.count + 3)) let clientSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: storedKey)) @@ -604,7 +607,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { guard storedKey == restoredKey else { throw SCRAMServerError.invalidProof } // Compute server signature - let serverKey = HMAC.authenticationCode(for: "Server Key".data(using: .utf8)!, using: .init(data: saltedPassword)) + let serverKey = HMAC.authenticationCode(for: Data("Server Key".utf8), using: .init(data: saltedPassword)) let serverSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: serverKey)) // Generate a `server-final-message` @@ -640,19 +643,12 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { HMAC() == output length of H(). ```` */ -private func Hi(string: [UInt8], salt: [UInt8], iterations: UInt32) -> [UInt8] { - let key = SymmetricKey(data: string) - var Ui = HMAC.authenticationCode(for: salt + [0x00, 0x00, 0x00, 0x01], using: key) // salt + 0x00000001 as big-endian - var Hi = Array(Ui) - - Hi.withUnsafeMutableBytes { Hibuf -> Void in - for _ in 2...iterations { - Ui = HMAC.authenticationCode(for: Data(Ui), using: key) - - Ui.withUnsafeBytes { Uibuf -> Void in - for i in 0.. SymmetricKey { + try KDF.Insecure.PBKDF2.deriveKey( + from: string, + salt: salt, + using: .sha256, + outputByteCount: 32, + unsafeUncheckedRounds: Int(iterations) + ) } diff --git a/Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift index fb0bfce1..23165746 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift @@ -1,13 +1,14 @@ import _ConnectionPoolModule -import XCTest +import Testing -final class ConnectionIDGeneratorTests: XCTestCase { - func testGenerateConnectionIDs() async { +@Suite struct ConnectionIDGeneratorTests { + + @Test func testGenerateConnectionIDs() async { let idGenerator = ConnectionIDGenerator() - XCTAssertEqual(idGenerator.next(), 0) - XCTAssertEqual(idGenerator.next(), 1) - XCTAssertEqual(idGenerator.next(), 2) + #expect(idGenerator.next() == 0) + #expect(idGenerator.next() == 1) + #expect(idGenerator.next() == 2) await withTaskGroup(of: Void.self) { taskGroup in for _ in 0..<1000 { @@ -17,6 +18,6 @@ final class ConnectionIDGeneratorTests: XCTestCase { } } - XCTAssertEqual(idGenerator.next(), 1003) + #expect(idGenerator.next() == 1003) } } diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index 3c0e7a6b..f3664242 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -1,12 +1,14 @@ @testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils import Atomics -import XCTest import NIOEmbedded +import Testing -@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class ConnectionPoolTests: XCTestCase { - func test1000ConsecutiveRequestsOnSingleConnection() async { +@Suite struct ConnectionPoolTests { + + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func test1000ConsecutiveRequestsOnSingleConnection() async { let factory = MockConnectionFactory() var config = ConnectionPoolConfiguration() @@ -33,37 +35,35 @@ final class ConnectionPoolTests: XCTestCase { let createdConnection = await factory.nextConnectAttempt { _ in return 1 } - XCTAssertNotNil(createdConnection) do { for _ in 0..<1000 { - async let connectionFuture = try await pool.leaseConnection() - var leasedConnection: MockConnection? - XCTAssertEqual(factory.pendingConnectionAttemptsCount, 0) - leasedConnection = try await connectionFuture - XCTAssertNotNil(leasedConnection) - XCTAssert(createdConnection === leasedConnection) - - if let leasedConnection { - pool.releaseConnection(leasedConnection) - } + async let connectionFuture = pool.leaseConnection() + var connectionLease: ConnectionLease? + #expect(factory.pendingConnectionAttemptsCount == 0) + connectionLease = try await connectionFuture + #expect(connectionLease != nil) + #expect(createdConnection === connectionLease?.connection) + + connectionLease?.release() } } catch { - XCTFail("Unexpected error: \(error)") + Issue.record("Unexpected error: \(error)") } taskGroup.cancelAll() - XCTAssertEqual(factory.pendingConnectionAttemptsCount, 0) + #expect(factory.pendingConnectionAttemptsCount == 0) for connection in factory.runningConnections { connection.closeIfClosing() } } - XCTAssertEqual(factory.runningConnections.count, 0) + #expect(factory.runningConnections.count == 0) } - func testShutdownPoolWhileConnectionIsBeingCreated() async { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testShutdownPoolWhileConnectionIsBeingCreated() async { let clock = MockClock() let factory = MockConnectionFactory() @@ -108,7 +108,8 @@ final class ConnectionPoolTests: XCTestCase { struct ConnectionCreationError: Error {} } - func testShutdownPoolWhileConnectionIsBackingOff() async { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testShutdownPoolWhileConnectionIsBackingOff() async { let clock = MockClock() let factory = MockConnectionFactory() @@ -143,7 +144,8 @@ final class ConnectionPoolTests: XCTestCase { struct ConnectionCreationError: Error {} } - func testConnectionHardLimitIsRespected() async { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testConnectionHardLimitIsRespected() async { let factory = MockConnectionFactory() var mutableConfig = ConnectionPoolConfiguration() @@ -172,21 +174,21 @@ final class ConnectionPoolTests: XCTestCase { await withTaskGroup(of: Void.self) { taskGroup in taskGroup.addTask_ { await pool.run() - XCTAssertFalse(hasFinished.compareExchange(expected: false, desired: true, ordering: .relaxed).original) + #expect(hasFinished.compareExchange(expected: false, desired: true, ordering: .relaxed).original == false) } taskGroup.addTask_ { var usedConnectionIDs = Set() for _ in 0..() let keepAliveDuration = Duration.seconds(30) @@ -249,16 +252,16 @@ final class ConnectionPoolTests: XCTestCase { await pool.run() } - async let lease1ConnectionAsync = pool.leaseConnection() + async let connectionLeaseFuture = pool.leaseConnection() let connection = await factory.nextConnectAttempt { connectionID in return 1 } - let lease1Connection = try await lease1ConnectionAsync - XCTAssert(connection === lease1Connection) + let connectionLease = try await connectionLeaseFuture + #expect(connection === connectionLease.connection) - pool.releaseConnection(lease1Connection) + connectionLease.release() // keep alive 1 @@ -266,11 +269,11 @@ final class ConnectionPoolTests: XCTestCase { var expectedInstants: Set = [.init(keepAliveDuration), .init(config.idleTimeout)] let deadline1 = await clock.nextTimerScheduled() print(deadline1) - XCTAssertNotNil(expectedInstants.remove(deadline1)) + #expect(expectedInstants.remove(deadline1) != nil) let deadline2 = await clock.nextTimerScheduled() print(deadline2) - XCTAssertNotNil(expectedInstants.remove(deadline2)) - XCTAssert(expectedInstants.isEmpty) + #expect(expectedInstants.remove(deadline2) != nil) + #expect(expectedInstants.isEmpty == true) // move clock forward to keep alive let newTime = clock.now.advanced(by: keepAliveDuration) @@ -279,14 +282,14 @@ final class ConnectionPoolTests: XCTestCase { await keepAlive.nextKeepAlive { keepAliveConnection in defer { print("keep alive 1 has run") } - XCTAssertTrue(keepAliveConnection === lease1Connection) + #expect(keepAliveConnection === connectionLease.connection) return true } // keep alive 2 let deadline3 = await clock.nextTimerScheduled() - XCTAssertEqual(deadline3, clock.now.advanced(by: keepAliveDuration)) + #expect(deadline3 == clock.now.advanced(by: keepAliveDuration)) print(deadline3) // race keep alive vs timeout @@ -300,7 +303,8 @@ final class ConnectionPoolTests: XCTestCase { } } - func testKeepAliveOnClose() async throws { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testKeepAliveOnClose() async throws { let clock = MockClock() let factory = MockConnectionFactory() let keepAliveDuration = Duration.seconds(20) @@ -328,16 +332,16 @@ final class ConnectionPoolTests: XCTestCase { await pool.run() } - async let lease1ConnectionAsync = pool.leaseConnection() + async let connectionLeaseFuture = pool.leaseConnection() let connection = await factory.nextConnectAttempt { connectionID in return 1 } - let lease1Connection = try await lease1ConnectionAsync - XCTAssert(connection === lease1Connection) + let connectionLease = try await connectionLeaseFuture + #expect(connection === connectionLease.connection) - pool.releaseConnection(lease1Connection) + connectionLease.release() // keep alive 1 @@ -345,38 +349,38 @@ final class ConnectionPoolTests: XCTestCase { var expectedInstants: Set = [.init(keepAliveDuration), .init(config.idleTimeout)] let deadline1 = await clock.nextTimerScheduled() print(deadline1) - XCTAssertNotNil(expectedInstants.remove(deadline1)) + #expect(expectedInstants.remove(deadline1) != nil) let deadline2 = await clock.nextTimerScheduled() print(deadline2) - XCTAssertNotNil(expectedInstants.remove(deadline2)) - XCTAssert(expectedInstants.isEmpty) + #expect(expectedInstants.remove(deadline2) != nil) + #expect(expectedInstants.isEmpty) // move clock forward to keep alive let newTime = clock.now.advanced(by: keepAliveDuration) clock.advance(to: newTime) await keepAlive.nextKeepAlive { keepAliveConnection in - XCTAssertTrue(keepAliveConnection === lease1Connection) + #expect(keepAliveConnection === connectionLease.connection) return true } // keep alive 2 let deadline3 = await clock.nextTimerScheduled() - XCTAssertEqual(deadline3, clock.now.advanced(by: keepAliveDuration)) + #expect(deadline3 == clock.now.advanced(by: keepAliveDuration)) clock.advance(to: clock.now.advanced(by: keepAliveDuration)) let failingKeepAliveDidRun = ManagedAtomic(false) // the following keep alive should not cause a crash _ = try? await keepAlive.nextKeepAlive { keepAliveConnection in defer { - XCTAssertFalse(failingKeepAliveDidRun - .compareExchange(expected: false, desired: true, ordering: .relaxed).original) + #expect(failingKeepAliveDidRun + .compareExchange(expected: false, desired: true, ordering: .relaxed).original == false) } - XCTAssertTrue(keepAliveConnection === lease1Connection) + #expect(keepAliveConnection === connectionLease.connection) keepAliveConnection.close() throw CancellationError() // any error } // will fail and it's expected - XCTAssertTrue(failingKeepAliveDidRun.load(ordering: .relaxed)) + #expect(failingKeepAliveDidRun.load(ordering: .relaxed) == true) taskGroup.cancelAll() @@ -386,7 +390,8 @@ final class ConnectionPoolTests: XCTestCase { } } - func testKeepAliveWorksRacesAgainstShutdown() async throws { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testKeepAliveWorksRacesAgainstShutdown() async throws { let clock = MockClock() let factory = MockConnectionFactory() let keepAliveDuration = Duration.seconds(30) @@ -414,16 +419,16 @@ final class ConnectionPoolTests: XCTestCase { await pool.run() } - async let lease1ConnectionAsync = pool.leaseConnection() + async let connectionLeaseFuture = pool.leaseConnection() let connection = await factory.nextConnectAttempt { connectionID in return 1 } - let lease1Connection = try await lease1ConnectionAsync - XCTAssert(connection === lease1Connection) + let connectionLease = try await connectionLeaseFuture + #expect(connection === connectionLease.connection) - pool.releaseConnection(lease1Connection) + connectionLease.release() // keep alive 1 @@ -431,17 +436,17 @@ final class ConnectionPoolTests: XCTestCase { var expectedInstants: Set = [.init(keepAliveDuration), .init(config.idleTimeout)] let deadline1 = await clock.nextTimerScheduled() print(deadline1) - XCTAssertNotNil(expectedInstants.remove(deadline1)) + #expect(expectedInstants.remove(deadline1) != nil) let deadline2 = await clock.nextTimerScheduled() print(deadline2) - XCTAssertNotNil(expectedInstants.remove(deadline2)) - XCTAssert(expectedInstants.isEmpty) + #expect(expectedInstants.remove(deadline2) != nil) + #expect(expectedInstants.isEmpty) clock.advance(to: clock.now.advanced(by: keepAliveDuration)) await keepAlive.nextKeepAlive { keepAliveConnection in defer { print("keep alive 1 has run") } - XCTAssertTrue(keepAliveConnection === lease1Connection) + #expect(keepAliveConnection === connectionLease.connection) return true } @@ -454,7 +459,8 @@ final class ConnectionPoolTests: XCTestCase { } } - func testCancelConnectionRequestWorks() async throws { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testCancelConnectionRequestWorks() async throws { let clock = MockClock() let factory = MockConnectionFactory() let keepAliveDuration = Duration.seconds(30) @@ -502,9 +508,9 @@ final class ConnectionPoolTests: XCTestCase { let taskResult = await leaseTask.result switch taskResult { case .success: - XCTFail("Expected task failure") + Issue.record("Expected task failure") case .failure(let failure): - XCTAssertEqual(failure as? ConnectionPoolError, .requestCancelled) + #expect(failure as? ConnectionPoolError == .requestCancelled) } taskGroup.cancelAll() @@ -514,7 +520,8 @@ final class ConnectionPoolTests: XCTestCase { } } - func testLeasingMultipleConnectionsAtOnceWorks() async throws { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testLeasingMultipleConnectionsAtOnceWorks() async throws { let clock = MockClock() let factory = MockConnectionFactory() let keepAliveDuration = Duration.seconds(30) @@ -555,19 +562,19 @@ final class ConnectionPoolTests: XCTestCase { // lease 4 connections at once pool.leaseConnections(requests) - var connections = [MockConnection]() + var connectionLeases = [ConnectionLease]() for request in requests { let connection = try await request.future.success - connections.append(connection) + connectionLeases.append(connection) } // Ensure that we got 4 distinct connections - XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 4) + #expect(Set(connectionLeases.lazy.map(\.connection.id)).count == 4) // release all 4 leased connections - for connection in connections { - pool.releaseConnection(connection) + for lease in connectionLeases { + lease.release() } // shutdown @@ -578,7 +585,8 @@ final class ConnectionPoolTests: XCTestCase { } } - func testLeasingConnectionAfterShutdownIsInvokedFails() async throws { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testLeasingConnectionAfterShutdownIsInvokedFails() async throws { let clock = MockClock() let factory = MockConnectionFactory() let keepAliveDuration = Duration.seconds(30) @@ -619,10 +627,10 @@ final class ConnectionPoolTests: XCTestCase { do { _ = try await pool.leaseConnection() - XCTFail("Expected a failure") + Issue.record("Expected a failure") } catch { print("failed") - XCTAssertEqual(error as? ConnectionPoolError, .poolShutdown) + #expect(error as? ConnectionPoolError == .poolShutdown) } print("will close connections: \(factory.runningConnections)") @@ -633,7 +641,8 @@ final class ConnectionPoolTests: XCTestCase { } } - func testLeasingConnectionsAfterShutdownIsInvokedFails() async throws { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testLeasingConnectionsAfterShutdownIsInvokedFails() async throws { let clock = MockClock() let factory = MockConnectionFactory() let keepAliveDuration = Duration.seconds(30) @@ -681,9 +690,9 @@ final class ConnectionPoolTests: XCTestCase { for request in requests { do { _ = try await request.future.success - XCTFail("Expected a failure") + Issue.record("Expected a failure") } catch { - XCTAssertEqual(error as? ConnectionPoolError, .poolShutdown) + #expect(error as? ConnectionPoolError == .poolShutdown) } } @@ -694,7 +703,8 @@ final class ConnectionPoolTests: XCTestCase { } } - func testLeasingMultipleStreamsFromOneConnectionWorks() async throws { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testLeasingMultipleStreamsFromOneConnectionWorks() async throws { let clock = MockClock() let factory = MockConnectionFactory() let keepAliveDuration = Duration.seconds(30) @@ -726,7 +736,7 @@ final class ConnectionPoolTests: XCTestCase { // create 4 connection requests let requests = (0..<10).map { ConnectionFuture(id: $0) } pool.leaseConnections(requests) - var connections = [MockConnection]() + var connectionLeases = [ConnectionLease]() await factory.nextConnectAttempt { connectionID in return 10 @@ -734,15 +744,15 @@ final class ConnectionPoolTests: XCTestCase { for request in requests { let connection = try await request.future.success - connections.append(connection) + connectionLeases.append(connection) } // Ensure that all requests got the same connection - XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 1) + #expect(Set(connectionLeases.lazy.map(\.connection.id)).count == 1) // release all 10 leased streams - for connection in connections { - pool.releaseConnection(connection) + for lease in connectionLeases { + lease.release() } for _ in 0..<9 { @@ -759,7 +769,8 @@ final class ConnectionPoolTests: XCTestCase { } } - func testIncreasingAvailableStreamsWorks() async throws { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testIncreasingAvailableStreamsWorks() async throws { let clock = MockClock() let factory = MockConnectionFactory() let keepAliveDuration = Duration.seconds(30) @@ -791,41 +802,41 @@ final class ConnectionPoolTests: XCTestCase { // create 4 connection requests var requests = (0..<21).map { ConnectionFuture(id: $0) } pool.leaseConnections(requests) - var connections = [MockConnection]() + var connectionLease = [ConnectionLease]() await factory.nextConnectAttempt { connectionID in return 1 } - let connection = try await requests.first!.future.success - connections.append(connection) + let lease = try await requests.first!.future.success + connectionLease.append(lease) requests.removeFirst() - pool.connectionReceivedNewMaxStreamSetting(connection, newMaxStreamSetting: 21) + pool.connectionReceivedNewMaxStreamSetting(lease.connection, newMaxStreamSetting: 21) for (_, request) in requests.enumerated() { let connection = try await request.future.success - connections.append(connection) + connectionLease.append(connection) } // Ensure that all requests got the same connection - XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 1) + #expect(Set(connectionLease.lazy.map(\.connection.id)).count == 1) requests = (22..<42).map { ConnectionFuture(id: $0) } pool.leaseConnections(requests) // release all 21 leased streams in a single call - pool.releaseConnection(connection, streams: 21) + pool.releaseConnection(lease.connection, streams: 21) // ensure all 20 new requests got fulfilled for request in requests { let connection = try await request.future.success - connections.append(connection) + connectionLease.append(connection) } // release all 20 leased streams one by one for _ in requests { - pool.releaseConnection(connection, streams: 1) + pool.releaseConnection(lease.connection, streams: 1) } // shutdown @@ -839,14 +850,14 @@ final class ConnectionPoolTests: XCTestCase { struct ConnectionFuture: ConnectionRequestProtocol { let id: Int - let future: Future + let future: Future> init(id: Int) { self.id = id - self.future = Future(of: MockConnection.self) + self.future = Future(of: ConnectionLease.self) } - func complete(with result: Result) { + func complete(with result: Result, ConnectionPoolError>) { switch result { case .success(let success): self.future.yield(value: success) diff --git a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift index 5845267f..b4658df8 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift @@ -1,27 +1,29 @@ @testable import _ConnectionPoolModule -import XCTest +import _ConnectionPoolTestUtils +import Testing -final class ConnectionRequestTests: XCTestCase { +@Suite struct ConnectionRequestTests { - func testHappyPath() async throws { + @Test func testHappyPath() async throws { let mockConnection = MockConnection(id: 1) - let connection = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let lease = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation, any Error>) in let request = ConnectionRequest(id: 42, continuation: continuation) - XCTAssertEqual(request.id, 42) - continuation.resume(with: .success(mockConnection)) + #expect(request.id == 42) + let lease = ConnectionLease(connection: mockConnection) { _ in } + continuation.resume(with: .success(lease)) } - XCTAssert(connection === mockConnection) + #expect(lease.connection === mockConnection) } - func testSadPath() async throws { + @Test func testSadPath() async throws { do { _ = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in continuation.resume(with: .failure(ConnectionPoolError.requestCancelled)) } - XCTFail("This point should not be reached") + Issue.record("This point should not be reached") } catch { - XCTAssertEqual(error as? ConnectionPoolError, .requestCancelled) + #expect(error as? ConnectionPoolError == .requestCancelled) } } } diff --git a/Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift b/Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift index 081e867b..ce620cc3 100644 --- a/Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift +++ b/Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift @@ -1,60 +1,61 @@ @testable import _ConnectionPoolModule -import XCTest +import Testing -final class Max2SequenceTests: XCTestCase { - func testCountAndIsEmpty() async { +@Suite struct Max2SequenceTests { + + @Test func testCountAndIsEmpty() async { var sequence = Max2Sequence() - XCTAssertEqual(sequence.count, 0) - XCTAssertEqual(sequence.isEmpty, true) + #expect(sequence.count == 0) + #expect(sequence.isEmpty == true) sequence.append(1) - XCTAssertEqual(sequence.count, 1) - XCTAssertEqual(sequence.isEmpty, false) + #expect(sequence.count == 1) + #expect(sequence.isEmpty == false) sequence.append(2) - XCTAssertEqual(sequence.count, 2) - XCTAssertEqual(sequence.isEmpty, false) + #expect(sequence.count == 2) + #expect(sequence.isEmpty == false) } - func testOptionalInitializer() { + @Test func testOptionalInitializer() { let emptySequence = Max2Sequence(nil, nil) - XCTAssertEqual(emptySequence.count, 0) - XCTAssertEqual(emptySequence.isEmpty, true) + #expect(emptySequence.count == 0) + #expect(emptySequence.isEmpty == true) var emptySequenceIterator = emptySequence.makeIterator() - XCTAssertNil(emptySequenceIterator.next()) - XCTAssertNil(emptySequenceIterator.next()) - XCTAssertNil(emptySequenceIterator.next()) + #expect(emptySequenceIterator.next() == nil) + #expect(emptySequenceIterator.next() == nil) + #expect(emptySequenceIterator.next() == nil) let oneElemSequence1 = Max2Sequence(1, nil) - XCTAssertEqual(oneElemSequence1.count, 1) - XCTAssertEqual(oneElemSequence1.isEmpty, false) + #expect(oneElemSequence1.count == 1) + #expect(oneElemSequence1.isEmpty == false) var oneElemSequence1Iterator = oneElemSequence1.makeIterator() - XCTAssertEqual(oneElemSequence1Iterator.next(), 1) - XCTAssertNil(oneElemSequence1Iterator.next()) - XCTAssertNil(oneElemSequence1Iterator.next()) + #expect(oneElemSequence1Iterator.next() == 1) + #expect(oneElemSequence1Iterator.next() == nil) + #expect(oneElemSequence1Iterator.next() == nil) let oneElemSequence2 = Max2Sequence(nil, 2) - XCTAssertEqual(oneElemSequence2.count, 1) - XCTAssertEqual(oneElemSequence2.isEmpty, false) + #expect(oneElemSequence2.count == 1) + #expect(oneElemSequence2.isEmpty == false) var oneElemSequence2Iterator = oneElemSequence2.makeIterator() - XCTAssertEqual(oneElemSequence2Iterator.next(), 2) - XCTAssertNil(oneElemSequence2Iterator.next()) - XCTAssertNil(oneElemSequence2Iterator.next()) + #expect(oneElemSequence2Iterator.next() == 2) + #expect(oneElemSequence2Iterator.next() == nil) + #expect(oneElemSequence2Iterator.next() == nil) let twoElemSequence = Max2Sequence(1, 2) - XCTAssertEqual(twoElemSequence.count, 2) - XCTAssertEqual(twoElemSequence.isEmpty, false) + #expect(twoElemSequence.count == 2) + #expect(twoElemSequence.isEmpty == false) var twoElemSequenceIterator = twoElemSequence.makeIterator() - XCTAssertEqual(twoElemSequenceIterator.next(), 1) - XCTAssertEqual(twoElemSequenceIterator.next(), 2) - XCTAssertNil(twoElemSequenceIterator.next()) + #expect(twoElemSequenceIterator.next() == 1) + #expect(twoElemSequenceIterator.next() == 2) + #expect(twoElemSequenceIterator.next() == nil) } func testMap() { let twoElemSequence = Max2Sequence(1, 2).map({ "\($0)" }) - XCTAssertEqual(twoElemSequence.count, 2) - XCTAssertEqual(twoElemSequence.isEmpty, false) + #expect(twoElemSequence.count == 2) + #expect(twoElemSequence.isEmpty == false) var twoElemSequenceIterator = twoElemSequence.makeIterator() - XCTAssertEqual(twoElemSequenceIterator.next(), "1") - XCTAssertEqual(twoElemSequenceIterator.next(), "2") - XCTAssertNil(twoElemSequenceIterator.next()) + #expect(twoElemSequenceIterator.next() == "1") + #expect(twoElemSequenceIterator.next() == "2") + #expect(twoElemSequenceIterator.next() == nil) } } diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift deleted file mode 100644 index 6aaa9c91..00000000 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift +++ /dev/null @@ -1,28 +0,0 @@ -import _ConnectionPoolModule - -final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { - typealias Connection = MockConnection - - struct ID: Hashable { - var objectID: ObjectIdentifier - - init(_ request: MockRequest) { - self.objectID = ObjectIdentifier(request) - } - } - - var id: ID { ID(self) } - - - static func ==(lhs: MockRequest, rhs: MockRequest) -> Bool { - lhs.id == rhs.id - } - - func hash(into hasher: inout Hasher) { - hasher.combine(self.id) - } - - func complete(with: Result) { - - } -} diff --git a/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift index b817ce19..ef6b001a 100644 --- a/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift +++ b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift @@ -1,10 +1,12 @@ import _ConnectionPoolModule -import XCTest +import _ConnectionPoolTestUtils +import Testing -@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class NoKeepAliveBehaviorTests: XCTestCase { - func testNoKeepAlive() { + +@Suite struct NoKeepAliveBehaviorTests { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testNoKeepAlive() { let keepAliveBehavior = NoOpKeepAliveBehavior(connectionType: MockConnection.self) - XCTAssertNil(keepAliveBehavior.keepAliveFrequency) + #expect(keepAliveBehavior.keepAliveFrequency == nil) } } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift index 6b8d6c6e..6bfe0f39 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift @@ -1,21 +1,12 @@ -import XCTest @testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils +import Testing -@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class PoolStateMachine_ConnectionGroupTests: XCTestCase { - var idGenerator: ConnectionIDGenerator! +@Suite struct PoolStateMachine_ConnectionGroupTests { + var idGenerator = ConnectionIDGenerator() - override func setUp() { - self.idGenerator = ConnectionIDGenerator() - super.setUp() - } - - override func tearDown() { - self.idGenerator = nil - super.tearDown() - } - - func testRefillConnections() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testRefillConnections() { var connections = TestPoolStateMachine.ConnectionGroup( generator: self.idGenerator, minimumConcurrentConnections: 4, @@ -25,35 +16,36 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { keepAliveReducesAvailableStreams: true ) - XCTAssertTrue(connections.isEmpty) + #expect(connections.isEmpty == true) let requests = connections.refillConnections() - XCTAssertFalse(connections.isEmpty) + #expect(connections.isEmpty == false) - XCTAssertEqual(requests.count, 4) - XCTAssertNil(connections.createNewDemandConnectionIfPossible()) - XCTAssertNil(connections.createNewOverflowConnectionIfPossible()) - XCTAssertEqual(connections.stats, .init(connecting: 4)) - XCTAssertEqual(connections.soonAvailableConnections, 4) + #expect(requests.count == 4) + #expect(connections.createNewDemandConnectionIfPossible() == nil) + #expect(connections.createNewOverflowConnectionIfPossible() == nil) + #expect(connections.stats == .init(connecting: 4)) + #expect(connections.soonAvailableConnections == 4) let requests2 = connections.refillConnections() - XCTAssertTrue(requests2.isEmpty) + #expect(requests2.isEmpty == true) var connected: UInt16 = 0 for request in requests { let newConnection = MockConnection(id: request.connectionID) let (_, context) = connections.newConnectionEstablished(newConnection, maxStreams: 1) - XCTAssertEqual(context.info, .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(context.use, .persisted) + #expect(context.info == .idle(availableStreams: 1, newIdle: true)) + #expect(context.use == .persisted) connected += 1 - XCTAssertEqual(connections.stats, .init(connecting: 4 - connected, idle: connected, availableStreams: connected)) - XCTAssertEqual(connections.soonAvailableConnections, 4 - connected) + #expect(connections.stats == .init(connecting: 4 - connected, idle: connected, availableStreams: connected)) + #expect(connections.soonAvailableConnections == 4 - connected) } let requests3 = connections.refillConnections() - XCTAssertTrue(requests3.isEmpty) + #expect(requests3.isEmpty == true) } - func testMakeConnectionLeaseItAndDropItHappyPath() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testMakeConnectionLeaseItAndDropItHappyPath() { var connections = TestPoolStateMachine.ConnectionGroup( generator: self.idGenerator, minimumConcurrentConnections: 0, @@ -64,70 +56,77 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { ) let requests = connections.refillConnections() - XCTAssertTrue(connections.isEmpty) - XCTAssertTrue(requests.isEmpty) + #expect(connections.isEmpty) + #expect(requests.isEmpty) guard let request = connections.createNewDemandConnectionIfPossible() else { - return XCTFail("Expected to receive a connection request") + Issue.record("Expected to receive a connection request") + return } - XCTAssertEqual(request, .init(connectionID: 0)) - XCTAssertFalse(connections.isEmpty) - XCTAssertEqual(connections.soonAvailableConnections, 1) - XCTAssertEqual(connections.stats, .init(connecting: 1)) + #expect(request == .init(connectionID: 0)) + #expect(!connections.isEmpty) + #expect(connections.soonAvailableConnections == 1) + #expect(connections.stats == .init(connecting: 1)) let newConnection = MockConnection(id: request.connectionID) let (_, establishedContext) = connections.newConnectionEstablished(newConnection, maxStreams: 1) - XCTAssertEqual(establishedContext.info, .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(establishedContext.use, .demand) - XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) - XCTAssertEqual(connections.soonAvailableConnections, 0) + #expect(establishedContext.info == .idle(availableStreams: 1, newIdle: true)) + #expect(establishedContext.use == .demand) + #expect(connections.stats == .init(idle: 1, availableStreams: 1)) + #expect(connections.soonAvailableConnections == 0) guard case .leasedConnection(let leaseResult) = connections.leaseConnectionOrSoonAvailableConnectionCount() else { - return XCTFail("Expected to lease a connection") + Issue.record("Expected to lease a connection") + return } - XCTAssert(newConnection === leaseResult.connection) - XCTAssertEqual(connections.stats, .init(leased: 1, leasedStreams: 1)) + #expect(newConnection === leaseResult.connection) + #expect(connections.stats == .init(leased: 1, leasedStreams: 1)) guard let (index, releasedContext) = connections.releaseConnection(leaseResult.connection.id, streams: 1) else { - return XCTFail("Expected that this connection is still active") + Issue.record("Expected that this connection is still active") + return } - XCTAssertEqual(releasedContext.info, .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(releasedContext.use, .demand) - XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + #expect(releasedContext.info == .idle(availableStreams: 1, newIdle: true)) + #expect(releasedContext.use == .demand) + #expect(connections.stats == .init(idle: 1, availableStreams: 1)) let parkTimers = connections.parkConnection(at: index, hasBecomeIdle: true) - XCTAssertEqual(parkTimers, [ + #expect(parkTimers == [ .init(timerID: 0, connectionID: newConnection.id, usecase: .keepAlive), .init(timerID: 1, connectionID: newConnection.id, usecase: .idleTimeout), ]) guard let keepAliveAction = connections.keepAliveIfIdle(newConnection.id) else { - return XCTFail("Expected to get a connection for ping pong") + Issue.record("Expected to get a connection for ping pong") + return } - XCTAssert(newConnection === keepAliveAction.connection) - XCTAssertEqual(connections.stats, .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) + #expect(newConnection === keepAliveAction.connection) + #expect(connections.stats == .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) guard let (_, pingPongContext) = connections.keepAliveSucceeded(newConnection.id) else { - return XCTFail("Expected to get an AvailableContext") + Issue.record("Expected to get an AvailableContext") + return } - XCTAssertEqual(pingPongContext.info, .idle(availableStreams: 1, newIdle: false)) - XCTAssertEqual(releasedContext.use, .demand) - XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + #expect(pingPongContext.info == .idle(availableStreams: 1, newIdle: false)) + #expect(releasedContext.use == .demand) + #expect(connections.stats == .init(idle: 1, availableStreams: 1)) guard let closeAction = connections.closeConnectionIfIdle(newConnection.id) else { - return XCTFail("Expected to get a connection for ping pong") + Issue.record("Expected to get a connection for ping pong") + return } - XCTAssertEqual(closeAction.timersToCancel, []) - XCTAssert(closeAction.connection === newConnection) - XCTAssertEqual(connections.stats, .init(closing: 1, availableStreams: 0)) + #expect(closeAction.timersToCancel == []) + #expect(closeAction.connection === newConnection) + #expect(connections.stats == .init(closing: 1, availableStreams: 0)) let closeContext = connections.connectionClosed(newConnection.id) - XCTAssertEqual(closeContext.connectionsStarting, 0) - XCTAssertTrue(connections.isEmpty) - XCTAssertEqual(connections.stats, .init()) + #expect(closeContext.connectionsStarting == 0) + #expect(connections.isEmpty) + #expect(connections.stats == .init()) } - func testBackoffDoneCreatesANewConnectionToReachMinimumConnectionsEvenThoughRetryIsSetToFalse() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testBackoffDoneCreatesANewConnectionToReachMinimumConnectionsEvenThoughRetryIsSetToFalse() { var connections = TestPoolStateMachine.ConnectionGroup( generator: self.idGenerator, minimumConcurrentConnections: 1, @@ -138,26 +137,30 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { ) let requests = connections.refillConnections() - XCTAssertEqual(connections.stats, .init(connecting: 1)) - XCTAssertEqual(connections.soonAvailableConnections, 1) - XCTAssertFalse(connections.isEmpty) - XCTAssertEqual(requests.count, 1) - - guard let request = requests.first else { return XCTFail("Expected to receive a connection request") } - XCTAssertEqual(request, .init(connectionID: 0)) + #expect(connections.stats == .init(connecting: 1)) + #expect(connections.soonAvailableConnections == 1) + #expect(!connections.isEmpty) + #expect(requests.count == 1) + + guard let request = requests.first else { + Issue.record("Expected to receive a connection request") + return + } + #expect(request == .init(connectionID: 0)) let backoffTimer = connections.backoffNextConnectionAttempt(request.connectionID) - XCTAssertEqual(connections.stats, .init(backingOff: 1)) + #expect(connections.stats == .init(backingOff: 1)) let backoffTimerCancellationToken = MockTimerCancellationToken(backoffTimer) - XCTAssertNil(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken)) + #expect(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken) == nil) let backoffDoneAction = connections.backoffDone(request.connectionID, retry: false) - XCTAssertEqual(backoffDoneAction, .createConnection(.init(connectionID: 0), backoffTimerCancellationToken)) + #expect(backoffDoneAction == .createConnection(.init(connectionID: 0), backoffTimerCancellationToken)) - XCTAssertEqual(connections.stats, .init(connecting: 1)) + #expect(connections.stats == .init(connecting: 1)) } - func testBackoffDoneCancelsIdleTimerIfAPersistedConnectionIsNotRetried() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testBackoffDoneCancelsIdleTimerIfAPersistedConnectionIsNotRetried() { var connections = TestPoolStateMachine.ConnectionGroup( generator: self.idGenerator, minimumConcurrentConnections: 2, @@ -168,57 +171,60 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { ) let requests = connections.refillConnections() - XCTAssertEqual(connections.stats, .init(connecting: 2)) - XCTAssertEqual(connections.soonAvailableConnections, 2) - XCTAssertFalse(connections.isEmpty) - XCTAssertEqual(requests.count, 2) + #expect(connections.stats == .init(connecting: 2)) + #expect(connections.soonAvailableConnections == 2) + #expect(!connections.isEmpty) + #expect(requests.count == 2) var requestIterator = requests.makeIterator() guard let firstRequest = requestIterator.next(), let secondRequest = requestIterator.next() else { - return XCTFail("Expected to get two requests") + Issue.record("Expected to get two requests") + return } guard let thirdRequest = connections.createNewDemandConnectionIfPossible() else { - return XCTFail("Expected to get another request") + Issue.record("Expected to get another request") + return } - XCTAssertEqual(connections.stats, .init(connecting: 3)) + #expect(connections.stats == .init(connecting: 3)) let newSecondConnection = MockConnection(id: secondRequest.connectionID) let (_, establishedSecondConnectionContext) = connections.newConnectionEstablished(newSecondConnection, maxStreams: 1) - XCTAssertEqual(establishedSecondConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(establishedSecondConnectionContext.use, .persisted) - XCTAssertEqual(connections.stats, .init(connecting: 2, idle: 1, availableStreams: 1)) - XCTAssertEqual(connections.soonAvailableConnections, 2) + #expect(establishedSecondConnectionContext.info == .idle(availableStreams: 1, newIdle: true)) + #expect(establishedSecondConnectionContext.use == .persisted) + #expect(connections.stats == .init(connecting: 2, idle: 1, availableStreams: 1)) + #expect(connections.soonAvailableConnections == 2) let newThirdConnection = MockConnection(id: thirdRequest.connectionID) let (thirdConnectionIndex, establishedThirdConnectionContext) = connections.newConnectionEstablished(newThirdConnection, maxStreams: 1) - XCTAssertEqual(establishedThirdConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(establishedThirdConnectionContext.use, .demand) - XCTAssertEqual(connections.stats, .init(connecting: 1, idle: 2, availableStreams: 2)) - XCTAssertEqual(connections.soonAvailableConnections, 1) + #expect(establishedThirdConnectionContext.info == .idle(availableStreams: 1, newIdle: true)) + #expect(establishedThirdConnectionContext.use == .demand) + #expect(connections.stats == .init(connecting: 1, idle: 2, availableStreams: 2)) + #expect(connections.soonAvailableConnections == 1) let thirdConnKeepTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: thirdRequest.connectionID, usecase: .keepAlive) let thirdConnIdleTimer = TestPoolStateMachine.ConnectionTimer(timerID: 1, connectionID: thirdRequest.connectionID, usecase: .idleTimeout) let thirdConnIdleTimerCancellationToken = MockTimerCancellationToken(thirdConnIdleTimer) - XCTAssertEqual(connections.parkConnection(at: thirdConnectionIndex, hasBecomeIdle: true), [thirdConnKeepTimer, thirdConnIdleTimer]) + #expect(connections.parkConnection(at: thirdConnectionIndex, hasBecomeIdle: true) == [thirdConnKeepTimer, thirdConnIdleTimer]) - XCTAssertNil(connections.timerScheduled(thirdConnKeepTimer, cancelContinuation: .init(thirdConnKeepTimer))) - XCTAssertNil(connections.timerScheduled(thirdConnIdleTimer, cancelContinuation: thirdConnIdleTimerCancellationToken)) + #expect(connections.timerScheduled(thirdConnKeepTimer, cancelContinuation: .init(thirdConnKeepTimer)) == nil) + #expect(connections.timerScheduled(thirdConnIdleTimer, cancelContinuation: thirdConnIdleTimerCancellationToken) == nil) let backoffTimer = connections.backoffNextConnectionAttempt(firstRequest.connectionID) - XCTAssertEqual(connections.stats, .init(backingOff: 1, idle: 2, availableStreams: 2)) + #expect(connections.stats == .init(backingOff: 1, idle: 2, availableStreams: 2)) let backoffTimerCancellationToken = MockTimerCancellationToken(backoffTimer) - XCTAssertNil(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken)) - XCTAssertEqual(connections.stats, .init(backingOff: 1, idle: 2, availableStreams: 2)) + #expect(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken) == nil) + #expect(connections.stats == .init(backingOff: 1, idle: 2, availableStreams: 2)) // connection three should be moved to connection one and for this reason become permanent - XCTAssertEqual(connections.backoffDone(firstRequest.connectionID, retry: false), .cancelTimers([backoffTimerCancellationToken, thirdConnIdleTimerCancellationToken])) - XCTAssertEqual(connections.stats, .init(idle: 2, availableStreams: 2)) + #expect(connections.backoffDone(firstRequest.connectionID, retry: false) == .cancelTimers([backoffTimerCancellationToken, thirdConnIdleTimerCancellationToken])) + #expect(connections.stats == .init(idle: 2, availableStreams: 2)) - XCTAssertNil(connections.closeConnectionIfIdle(newThirdConnection.id)) + #expect(connections.closeConnectionIfIdle(newThirdConnection.id) == nil) } - func testBackoffDoneReturnsNilIfOverflowConnection() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testBackoffDoneReturnsNilIfOverflowConnection() { var connections = TestPoolStateMachine.ConnectionGroup( generator: self.idGenerator, minimumConcurrentConnections: 0, @@ -229,33 +235,36 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { ) guard let firstRequest = connections.createNewDemandConnectionIfPossible() else { - return XCTFail("Expected to get two requests") + Issue.record("Expected to get two requests") + return } guard let secondRequest = connections.createNewDemandConnectionIfPossible() else { - return XCTFail("Expected to get another request") + Issue.record("Expected to get another request") + return } - XCTAssertEqual(connections.stats, .init(connecting: 2)) + #expect(connections.stats == .init(connecting: 2)) let newFirstConnection = MockConnection(id: firstRequest.connectionID) let (_, establishedFirstConnectionContext) = connections.newConnectionEstablished(newFirstConnection, maxStreams: 1) - XCTAssertEqual(establishedFirstConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(establishedFirstConnectionContext.use, .demand) - XCTAssertEqual(connections.stats, .init(connecting: 1, idle: 1, availableStreams: 1)) - XCTAssertEqual(connections.soonAvailableConnections, 1) + #expect(establishedFirstConnectionContext.info == .idle(availableStreams: 1, newIdle: true)) + #expect(establishedFirstConnectionContext.use == .demand) + #expect(connections.stats == .init(connecting: 1, idle: 1, availableStreams: 1)) + #expect(connections.soonAvailableConnections == 1) let backoffTimer = connections.backoffNextConnectionAttempt(secondRequest.connectionID) let backoffTimerCancellationToken = MockTimerCancellationToken(backoffTimer) - XCTAssertEqual(connections.stats, .init(backingOff: 1, idle: 1, availableStreams: 1)) - XCTAssertNil(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken)) + #expect(connections.stats == .init(backingOff: 1, idle: 1, availableStreams: 1)) + #expect(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken) == nil) - XCTAssertEqual(connections.backoffDone(secondRequest.connectionID, retry: false), .cancelTimers([backoffTimerCancellationToken])) - XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + #expect(connections.backoffDone(secondRequest.connectionID, retry: false) == .cancelTimers([backoffTimerCancellationToken])) + #expect(connections.stats == .init(idle: 1, availableStreams: 1)) - XCTAssertNotNil(connections.closeConnectionIfIdle(newFirstConnection.id)) + #expect(connections.closeConnectionIfIdle(newFirstConnection.id) != nil) } - func testPingPong() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testPingPong() { var connections = TestPoolStateMachine.ConnectionGroup( generator: self.idGenerator, minimumConcurrentConnections: 1, @@ -266,35 +275,40 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { ) let requests = connections.refillConnections() - XCTAssertFalse(connections.isEmpty) - XCTAssertEqual(connections.stats, .init(connecting: 1)) + #expect(!connections.isEmpty) + #expect(connections.stats == .init(connecting: 1)) - XCTAssertEqual(requests.count, 1) - guard let firstRequest = requests.first else { return XCTFail("Expected to have a request here") } + #expect(requests.count == 1) + guard let firstRequest = requests.first else { + Issue.record("Expected to have a request here") + return + } let newConnection = MockConnection(id: firstRequest.connectionID) let (connectionIndex, establishedConnectionContext) = connections.newConnectionEstablished(newConnection, maxStreams: 1) - XCTAssertEqual(establishedConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(establishedConnectionContext.use, .persisted) - XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + #expect(establishedConnectionContext.info == .idle(availableStreams: 1, newIdle: true)) + #expect(establishedConnectionContext.use == .persisted) + #expect(connections.stats == .init(idle: 1, availableStreams: 1)) let timers = connections.parkConnection(at: connectionIndex, hasBecomeIdle: true) let keepAliveTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: firstRequest.connectionID, usecase: .keepAlive) let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) - XCTAssertEqual(timers, [keepAliveTimer]) - XCTAssertNil(connections.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + #expect(timers == [keepAliveTimer]) + #expect(connections.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken) == nil) let keepAliveAction = connections.keepAliveIfIdle(newConnection.id) - XCTAssertEqual(keepAliveAction, .init(connection: newConnection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken)) - XCTAssertEqual(connections.stats, .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) + #expect(keepAliveAction == .init(connection: newConnection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken)) + #expect(connections.stats == .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) guard let (_, afterPingIdleContext) = connections.keepAliveSucceeded(newConnection.id) else { - return XCTFail("Expected to receive an AvailableContext") + Issue.record("Expected to receive an AvailableContext") + return } - XCTAssertEqual(afterPingIdleContext.info, .idle(availableStreams: 1, newIdle: false)) - XCTAssertEqual(afterPingIdleContext.use, .persisted) - XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + #expect(afterPingIdleContext.info == .idle(availableStreams: 1, newIdle: false)) + #expect(afterPingIdleContext.use == .persisted) + #expect(connections.stats == .init(idle: 1, availableStreams: 1)) } - func testKeepAliveShouldNotIndicateCloseConnectionAfterClosed() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testKeepAliveShouldNotIndicateCloseConnectionAfterClosed() { var connections = TestPoolStateMachine.ConnectionGroup( generator: self.idGenerator, minimumConcurrentConnections: 0, @@ -304,24 +318,28 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { keepAliveReducesAvailableStreams: true ) - guard let firstRequest = connections.createNewDemandConnectionIfPossible() else { return XCTFail("Expected to have a request here") } + guard let firstRequest = connections.createNewDemandConnectionIfPossible() else { + Issue.record("Expected to have a request here") + return + } let newConnection = MockConnection(id: firstRequest.connectionID) let (connectionIndex, establishedConnectionContext) = connections.newConnectionEstablished(newConnection, maxStreams: 1) - XCTAssertEqual(establishedConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + #expect(establishedConnectionContext.info == .idle(availableStreams: 1, newIdle: true)) + #expect(connections.stats == .init(idle: 1, availableStreams: 1)) _ = connections.parkConnection(at: connectionIndex, hasBecomeIdle: true) let keepAliveTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: firstRequest.connectionID, usecase: .keepAlive) let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) - XCTAssertNil(connections.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + #expect(connections.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken) == nil) let keepAliveAction = connections.keepAliveIfIdle(newConnection.id) - XCTAssertEqual(keepAliveAction, .init(connection: newConnection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken)) - XCTAssertEqual(connections.stats, .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) + #expect(keepAliveAction == .init(connection: newConnection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken)) + #expect(connections.stats == .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) _ = connections.closeConnectionIfIdle(newConnection.id) guard connections.keepAliveFailed(newConnection.id) == nil else { - return XCTFail("Expected keepAliveFailed not to cause close again") + Issue.record("Expected keepAliveFailed not to cause close again") + return } - XCTAssertEqual(connections.stats, .init(closing: 1)) + #expect(connections.stats == .init(closing: 1)) } } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift index bc4c2c4b..2d81cf38 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift @@ -1,35 +1,36 @@ @testable import _ConnectionPoolModule -import XCTest +import _ConnectionPoolTestUtils +import Testing -@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class PoolStateMachine_ConnectionStateTests: XCTestCase { +@Suite struct PoolStateMachine_ConnectionStateTests { typealias TestConnectionState = TestPoolStateMachine.ConnectionState - func testStartupLeaseReleaseParkLease() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testStartupLeaseReleaseParkLease() { let connectionID = 1 var state = TestConnectionState(id: connectionID) - XCTAssertEqual(state.id, connectionID) - XCTAssertEqual(state.isIdle, false) - XCTAssertEqual(state.isAvailable, false) - XCTAssertEqual(state.isConnected, false) - XCTAssertEqual(state.isLeased, false) + #expect(state.id == connectionID) + #expect(!state.isIdle) + #expect(!state.isAvailable) + #expect(!state.isConnected) + #expect(!state.isLeased) let connection = MockConnection(id: connectionID) - XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(state.isIdle, true) - XCTAssertEqual(state.isAvailable, true) - XCTAssertEqual(state.isConnected, true) - XCTAssertEqual(state.isLeased, false) - XCTAssertEqual(state.lease(streams: 1), .init(connection: connection, timersToCancel: .init(), wasIdle: true)) - - XCTAssertEqual(state.isIdle, false) - XCTAssertEqual(state.isAvailable, false) - XCTAssertEqual(state.isConnected, true) - XCTAssertEqual(state.isLeased, true) - - XCTAssertEqual(state.release(streams: 1), .idle(availableStreams: 1, newIdle: true)) + #expect(state.connected(connection, maxStreams: 1) == .idle(availableStreams: 1, newIdle: true)) + #expect(state.isIdle) + #expect(state.isAvailable) + #expect(state.isConnected) + #expect(state.isLeased == false) + #expect(state.lease(streams: 1) == .init(connection: connection, timersToCancel: .init(), wasIdle: true)) + + #expect(!state.isIdle) + #expect(!state.isAvailable) + #expect(state.isConnected) + #expect(state.isLeased) + + #expect(state.release(streams: 1) == .idle(availableStreams: 1, newIdle: true)) let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) - XCTAssert( + #expect( parkResult.elementsEqual([ .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive), .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout) @@ -37,31 +38,33 @@ final class PoolStateMachine_ConnectionStateTests: XCTestCase { ) guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { - return XCTFail("Expected to get two timers") + Issue.record("Expected to get two timers") + return } let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) let idleTimerCancellationToken = MockTimerCancellationToken(idleTimer) - XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) - XCTAssertNil(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken)) + #expect(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken) == nil) + #expect(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken) == nil) let expectLeaseAction = TestConnectionState.LeaseAction( connection: connection, timersToCancel: [idleTimerCancellationToken, keepAliveTimerCancellationToken], wasIdle: true ) - XCTAssertEqual(state.lease(streams: 1), expectLeaseAction) + #expect(state.lease(streams: 1) == expectLeaseAction) } - func testStartupParkLeaseBeforeTimersRegistered() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testStartupParkLeaseBeforeTimersRegistered() { let connectionID = 1 var state = TestConnectionState(id: connectionID) let connection = MockConnection(id: connectionID) - XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + #expect(state.connected(connection, maxStreams: 1) == .idle(availableStreams: 1, newIdle: true)) let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) - XCTAssertEqual( - parkResult, + #expect( + parkResult == [ .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive), .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout) @@ -69,24 +72,26 @@ final class PoolStateMachine_ConnectionStateTests: XCTestCase { ) guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { - return XCTFail("Expected to get two timers") + Issue.record("Expected to get two timers") + return } let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) let idleTimerCancellationToken = MockTimerCancellationToken(idleTimer) - XCTAssertEqual(state.lease(streams: 1), .init(connection: connection, timersToCancel: .init(), wasIdle: true)) + #expect(state.lease(streams: 1) == .init(connection: connection, timersToCancel: .init(), wasIdle: true)) - XCTAssertEqual(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken), keepAliveTimerCancellationToken) - XCTAssertEqual(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken), idleTimerCancellationToken) + #expect(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken) == keepAliveTimerCancellationToken) + #expect(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken) == idleTimerCancellationToken) } - func testStartupParkLeasePark() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testStartupParkLeasePark() { let connectionID = 1 var state = TestConnectionState(id: connectionID) let connection = MockConnection(id: connectionID) - XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + #expect(state.connected(connection, maxStreams: 1) == .idle(availableStreams: 1, newIdle: true)) let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) - XCTAssert( + #expect( parkResult.elementsEqual([ .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive), .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout) @@ -94,171 +99,186 @@ final class PoolStateMachine_ConnectionStateTests: XCTestCase { ) guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { - return XCTFail("Expected to get two timers") + Issue.record("Expected to get two timers") + return } let initialKeepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) let initialIdleTimerCancellationToken = MockTimerCancellationToken(idleTimer) - XCTAssertEqual(state.lease(streams: 1), .init(connection: connection, timersToCancel: .init(), wasIdle: true)) + #expect(state.lease(streams: 1) == .init(connection: connection, timersToCancel: .init(), wasIdle: true)) - XCTAssertEqual(state.release(streams: 1), .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual( - state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true), + #expect(state.release(streams: 1) == .idle(availableStreams: 1, newIdle: true)) + #expect( + state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) == [ .init(timerID: 2, connectionID: connectionID, usecase: .keepAlive), .init(timerID: 3, connectionID: connectionID, usecase: .idleTimeout) ] ) - XCTAssertEqual(state.timerScheduled(keepAliveTimer, cancelContinuation: initialKeepAliveTimerCancellationToken), initialKeepAliveTimerCancellationToken) - XCTAssertEqual(state.timerScheduled(idleTimer, cancelContinuation: initialIdleTimerCancellationToken), initialIdleTimerCancellationToken) + #expect(state.timerScheduled(keepAliveTimer, cancelContinuation: initialKeepAliveTimerCancellationToken) == initialKeepAliveTimerCancellationToken) + #expect(state.timerScheduled(idleTimer, cancelContinuation: initialIdleTimerCancellationToken) == initialIdleTimerCancellationToken) } - func testStartupFailed() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testStartupFailed() { let connectionID = 1 var state = TestConnectionState(id: connectionID) let firstBackoffTimer = state.failedToConnect() let firstBackoffTimerCancellationToken = MockTimerCancellationToken(firstBackoffTimer) - XCTAssertNil(state.timerScheduled(firstBackoffTimer, cancelContinuation: firstBackoffTimerCancellationToken)) - XCTAssertEqual(state.retryConnect(), firstBackoffTimerCancellationToken) + #expect(state.timerScheduled(firstBackoffTimer, cancelContinuation: firstBackoffTimerCancellationToken) == nil) + #expect(state.retryConnect() == firstBackoffTimerCancellationToken) let secondBackoffTimer = state.failedToConnect() let secondBackoffTimerCancellationToken = MockTimerCancellationToken(secondBackoffTimer) - XCTAssertNil(state.retryConnect()) - XCTAssertEqual( - state.timerScheduled(secondBackoffTimer, cancelContinuation: secondBackoffTimerCancellationToken), + #expect(state.retryConnect() == nil) + #expect( + state.timerScheduled(secondBackoffTimer, cancelContinuation: secondBackoffTimerCancellationToken) == secondBackoffTimerCancellationToken ) let thirdBackoffTimer = state.failedToConnect() let thirdBackoffTimerCancellationToken = MockTimerCancellationToken(thirdBackoffTimer) - XCTAssertNil(state.retryConnect()) + #expect(state.retryConnect() == nil) let forthBackoffTimer = state.failedToConnect() let forthBackoffTimerCancellationToken = MockTimerCancellationToken(forthBackoffTimer) - XCTAssertEqual( - state.timerScheduled(thirdBackoffTimer, cancelContinuation: thirdBackoffTimerCancellationToken), + #expect( + state.timerScheduled(thirdBackoffTimer, cancelContinuation: thirdBackoffTimerCancellationToken) == thirdBackoffTimerCancellationToken ) - XCTAssertNil( - state.timerScheduled(forthBackoffTimer, cancelContinuation: forthBackoffTimerCancellationToken) + #expect( + state.timerScheduled(forthBackoffTimer, cancelContinuation: forthBackoffTimerCancellationToken) == nil ) - XCTAssertEqual(state.retryConnect(), forthBackoffTimerCancellationToken) + #expect(state.retryConnect() == forthBackoffTimerCancellationToken) let connection = MockConnection(id: connectionID) - XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + #expect(state.connected(connection, maxStreams: 1) == .idle(availableStreams: 1, newIdle: true)) } - func testLeaseMultipleStreams() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testLeaseMultipleStreams() { let connectionID = 1 var state = TestConnectionState(id: connectionID) let connection = MockConnection(id: connectionID) - XCTAssertEqual(state.connected(connection, maxStreams: 100), .idle(availableStreams: 100, newIdle: true)) + #expect(state.connected(connection, maxStreams: 100) == .idle(availableStreams: 100, newIdle: true)) let timers = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: false) - guard let keepAliveTimer = timers.first else { return XCTFail("Expected to get a keepAliveTimer") } + guard let keepAliveTimer = timers.first else { + Issue.record("Expected to get a keepAliveTimer") + return + } let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) - XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + #expect(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken) == nil) - XCTAssertEqual( - state.lease(streams: 30), + #expect( + state.lease(streams: 30) == TestConnectionState.LeaseAction(connection: connection, timersToCancel: [keepAliveTimerCancellationToken], wasIdle: true) ) - XCTAssertEqual(state.release(streams: 10), .leased(availableStreams: 80)) + #expect(state.release(streams: 10) == .leased(availableStreams: 80)) - XCTAssertEqual( - state.lease(streams: 40), + #expect( + state.lease(streams: 40) == TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: false) ) - XCTAssertEqual( - state.lease(streams: 40), + #expect( + state.lease(streams: 40) == TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: false) ) - XCTAssertEqual(state.release(streams: 1), .leased(availableStreams: 1)) - XCTAssertEqual(state.release(streams: 98), .leased(availableStreams: 99)) - XCTAssertEqual(state.release(streams: 1), .idle(availableStreams: 100, newIdle: true)) + #expect(state.release(streams: 1) == .leased(availableStreams: 1)) + #expect(state.release(streams: 98) == .leased(availableStreams: 99)) + #expect(state.release(streams: 1) == .idle(availableStreams: 100, newIdle: true)) } - func testRunningKeepAliveReducesAvailableStreams() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testRunningKeepAliveReducesAvailableStreams() { let connectionID = 1 var state = TestConnectionState(id: connectionID) let connection = MockConnection(id: connectionID) - XCTAssertEqual(state.connected(connection, maxStreams: 100), .idle(availableStreams: 100, newIdle: true)) + #expect(state.connected(connection, maxStreams: 100) == .idle(availableStreams: 100, newIdle: true)) let timers = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: false) - guard let keepAliveTimer = timers.first else { return XCTFail("Expected to get a keepAliveTimer") } + guard let keepAliveTimer = timers.first else { + Issue.record("Expected to get a keepAliveTimer") + return + } let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) - XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + #expect(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken) == nil) - XCTAssertEqual( - state.runKeepAliveIfIdle(reducesAvailableStreams: true), + #expect( + state.runKeepAliveIfIdle(reducesAvailableStreams: true) == .init(connection: connection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken) ) - XCTAssertEqual( - state.lease(streams: 30), + #expect( + state.lease(streams: 30) == TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: true) ) - XCTAssertEqual(state.release(streams: 10), .leased(availableStreams: 79)) - XCTAssertEqual(state.isAvailable, true) - XCTAssertEqual( - state.lease(streams: 79), + #expect(state.release(streams: 10) == .leased(availableStreams: 79)) + #expect(state.isAvailable) + #expect( + state.lease(streams: 79) == TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: false) ) - XCTAssertEqual(state.isAvailable, false) - XCTAssertEqual(state.keepAliveSucceeded(), .leased(availableStreams: 1)) - XCTAssertEqual(state.isAvailable, true) + #expect(!state.isAvailable) + #expect(state.keepAliveSucceeded() == .leased(availableStreams: 1)) + #expect(state.isAvailable) } - func testRunningKeepAliveDoesNotReduceAvailableStreams() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testRunningKeepAliveDoesNotReduceAvailableStreams() { let connectionID = 1 var state = TestConnectionState(id: connectionID) let connection = MockConnection(id: connectionID) - XCTAssertEqual(state.connected(connection, maxStreams: 100), .idle(availableStreams: 100, newIdle: true)) + #expect(state.connected(connection, maxStreams: 100) == .idle(availableStreams: 100, newIdle: true)) let timers = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: false) - guard let keepAliveTimer = timers.first else { return XCTFail("Expected to get a keepAliveTimer") } + guard let keepAliveTimer = timers.first else { + Issue.record("Expected to get a keepAliveTimer") + return + } let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) - XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + #expect(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken) == nil) - XCTAssertEqual( - state.runKeepAliveIfIdle(reducesAvailableStreams: false), + #expect( + state.runKeepAliveIfIdle(reducesAvailableStreams: false) == .init(connection: connection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken) ) - XCTAssertEqual( - state.lease(streams: 30), + #expect( + state.lease(streams: 30) == TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: true) ) - XCTAssertEqual(state.release(streams: 10), .leased(availableStreams: 80)) - XCTAssertEqual(state.keepAliveSucceeded(), .leased(availableStreams: 80)) + #expect(state.release(streams: 10) == .leased(availableStreams: 80)) + #expect(state.keepAliveSucceeded() == .leased(availableStreams: 80)) } - func testRunKeepAliveRacesAgainstIdleClose() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testRunKeepAliveRacesAgainstIdleClose() { let connectionID = 1 var state = TestConnectionState(id: connectionID) let connection = MockConnection(id: connectionID) - XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + #expect(state.connected(connection, maxStreams: 1) == .idle(availableStreams: 1, newIdle: true)) let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { - return XCTFail("Expected to get two timers") + Issue.record("Expected to get two timers") + return } - XCTAssertEqual(keepAliveTimer, .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive)) - XCTAssertEqual(idleTimer, .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout)) + #expect(keepAliveTimer == .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive)) + #expect(idleTimer == .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout)) let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) let idleTimerCancellationToken = MockTimerCancellationToken(idleTimer) - XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) - XCTAssertNil(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken)) - - XCTAssertEqual(state.closeIfIdle(), .init(connection: connection, previousConnectionState: .idle, cancelTimers: [keepAliveTimerCancellationToken, idleTimerCancellationToken], usedStreams: 0, maxStreams: 1, runningKeepAlive: false)) - XCTAssertEqual(state.runKeepAliveIfIdle(reducesAvailableStreams: true), .none) + #expect(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken) == nil) + #expect(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken) == nil) + #expect(state.closeIfIdle() == .init(connection: connection, previousConnectionState: .idle, cancelTimers: [keepAliveTimerCancellationToken, idleTimerCancellationToken], usedStreams: 0, maxStreams: 1, runningKeepAlive: false)) + #expect(state.runKeepAliveIfIdle(reducesAvailableStreams: true) == .none) } } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift index 0231da51..458c6b3f 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift @@ -1,147 +1,152 @@ @testable import _ConnectionPoolModule -import XCTest +import _ConnectionPoolTestUtils +import Testing -@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class PoolStateMachine_RequestQueueTests: XCTestCase { +@Suite struct PoolStateMachine_RequestQueueTests { typealias TestQueue = TestPoolStateMachine.RequestQueue - func testHappyPath() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testHappyPath() { var queue = TestQueue() - XCTAssert(queue.isEmpty) + #expect(queue.isEmpty) - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - XCTAssertEqual(queue.count, 1) - XCTAssertFalse(queue.isEmpty) + #expect(queue.count == 1) + #expect(!queue.isEmpty) let popResult = queue.pop(max: 3) - XCTAssert(popResult.elementsEqual([request1])) - XCTAssert(queue.isEmpty) - XCTAssertEqual(queue.count, 0) + #expect(popResult.elementsEqual([request1])) + #expect(queue.isEmpty) + #expect(queue.count == 0) } - func testEnqueueAndPopMultipleRequests() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testEnqueueAndPopMultipleRequests() { var queue = TestQueue() - XCTAssert(queue.isEmpty) + #expect(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { - XCTAssertEqual(queue.count, 3) - XCTAssertFalse(queue.isEmpty) + #expect(queue.count == 3) + #expect(!queue.isEmpty) let popResult = queue.pop(max: 3) - XCTAssert(popResult.elementsEqual([request1, request2, request3])) - XCTAssert(queue.isEmpty) - XCTAssertEqual(queue.count, 0) + #expect(popResult.elementsEqual([request1, request2, request3])) + #expect(queue.isEmpty) + #expect(queue.count == 0) } - XCTAssert(isKnownUniquelyReferenced(&request1)) - XCTAssert(isKnownUniquelyReferenced(&request2)) - XCTAssert(isKnownUniquelyReferenced(&request3)) + #expect(isKnownUniquelyReferenced(&request1)) + #expect(isKnownUniquelyReferenced(&request2)) + #expect(isKnownUniquelyReferenced(&request3)) } - func testEnqueueAndPopOnlyOne() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testEnqueueAndPopOnlyOne() { var queue = TestQueue() - XCTAssert(queue.isEmpty) + #expect(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { - XCTAssertEqual(queue.count, 3) - XCTAssertFalse(queue.isEmpty) + #expect(queue.count == 3) + #expect(!queue.isEmpty) let popResult = queue.pop(max: 1) - XCTAssert(popResult.elementsEqual([request1])) - XCTAssertFalse(queue.isEmpty) - XCTAssertEqual(queue.count, 2) + #expect(popResult.elementsEqual([request1])) + #expect(!queue.isEmpty) + #expect(queue.count == 2) let removeAllResult = queue.removeAll() - XCTAssert(Set(removeAllResult) == [request2, request3]) + #expect(Set(removeAllResult) == [request2, request3]) } - XCTAssert(isKnownUniquelyReferenced(&request1)) - XCTAssert(isKnownUniquelyReferenced(&request2)) - XCTAssert(isKnownUniquelyReferenced(&request3)) + #expect(isKnownUniquelyReferenced(&request1)) + #expect(isKnownUniquelyReferenced(&request2)) + #expect(isKnownUniquelyReferenced(&request3)) } - func testCancellation() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testCancellation() { var queue = TestQueue() - XCTAssert(queue.isEmpty) + #expect(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { - XCTAssertEqual(queue.count, 3) + #expect(queue.count == 3) let returnedRequest2 = queue.remove(request2.id) - XCTAssert(returnedRequest2 === request2) - XCTAssertEqual(queue.count, 2) - XCTAssertFalse(queue.isEmpty) + #expect(returnedRequest2 === request2) + #expect(queue.count == 2) + #expect(!queue.isEmpty) } // still retained by the deque inside the queue - XCTAssertEqual(queue.requests.count, 2) - XCTAssertEqual(queue.queue.count, 3) + #expect(queue.requests.count == 2) + #expect(queue.queue.count == 3) do { - XCTAssertEqual(queue.count, 2) - XCTAssertFalse(queue.isEmpty) + #expect(queue.count == 2) + #expect(!queue.isEmpty) let popResult = queue.pop(max: 3) - XCTAssert(popResult.elementsEqual([request1, request3])) - XCTAssert(queue.isEmpty) - XCTAssertEqual(queue.count, 0) + #expect(popResult.elementsEqual([request1, request3])) + #expect(queue.isEmpty) + #expect(queue.count == 0) } - XCTAssert(isKnownUniquelyReferenced(&request1)) - XCTAssert(isKnownUniquelyReferenced(&request2)) - XCTAssert(isKnownUniquelyReferenced(&request3)) + #expect(isKnownUniquelyReferenced(&request1)) + #expect(isKnownUniquelyReferenced(&request2)) + #expect(isKnownUniquelyReferenced(&request3)) } - func testRemoveAllAfterCancellation() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testRemoveAllAfterCancellation() { var queue = TestQueue() - XCTAssert(queue.isEmpty) + #expect(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { - XCTAssertEqual(queue.count, 3) + #expect(queue.count == 3) let returnedRequest2 = queue.remove(request2.id) - XCTAssert(returnedRequest2 === request2) - XCTAssertEqual(queue.count, 2) - XCTAssertFalse(queue.isEmpty) + #expect(returnedRequest2 === request2) + #expect(queue.count == 2) + #expect(!queue.isEmpty) } // still retained by the deque inside the queue - XCTAssertEqual(queue.requests.count, 2) - XCTAssertEqual(queue.queue.count, 3) + #expect(queue.requests.count == 2) + #expect(queue.queue.count == 3) do { - XCTAssertEqual(queue.count, 2) - XCTAssertFalse(queue.isEmpty) + #expect(queue.count == 2) + #expect(!queue.isEmpty) let removeAllResult = queue.removeAll() - XCTAssert(Set(removeAllResult) == [request1, request3]) - XCTAssert(queue.isEmpty) - XCTAssertEqual(queue.count, 0) + #expect(Set(removeAllResult) == [request1, request3]) + #expect(queue.isEmpty) + #expect(queue.count == 0) } - XCTAssert(isKnownUniquelyReferenced(&request1)) - XCTAssert(isKnownUniquelyReferenced(&request2)) - XCTAssert(isKnownUniquelyReferenced(&request3)) + #expect(isKnownUniquelyReferenced(&request1)) + #expect(isKnownUniquelyReferenced(&request2)) + #expect(isKnownUniquelyReferenced(&request3)) } } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift index 2f3ae617..c748de28 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -1,20 +1,21 @@ -import XCTest @testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils +import Testing @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) typealias TestPoolStateMachine = PoolStateMachine< MockConnection, ConnectionIDGenerator, MockConnection.ID, - MockRequest, - MockRequest.ID, + MockRequest, + MockRequest.ID, MockTimerCancellationToken > -@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class PoolStateMachineTests: XCTestCase { +@Suite struct PoolStateMachineTests { - func testConnectionsAreCreatedAndParkedOnStartup() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testConnectionsAreCreatedAndParkedOnStartup() { var configuration = PoolConfiguration() configuration.minimumConnectionCount = 2 configuration.maximumConnectionSoftLimit = 4 @@ -32,25 +33,26 @@ final class PoolStateMachineTests: XCTestCase { do { let requests = stateMachine.refillConnections() - XCTAssertEqual(requests.count, 2) + #expect(requests.count == 2) let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) let connection1KeepAliveTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 0, usecase: .keepAlive), duration: .seconds(10)) let connection1KeepAliveTimerCancellationToken = MockTimerCancellationToken(connection1KeepAliveTimer) - XCTAssertEqual(createdAction1.request, .none) - XCTAssertEqual(createdAction1.connection, .scheduleTimers([connection1KeepAliveTimer])) + #expect(createdAction1.request == .none) + #expect(createdAction1.connection == .scheduleTimers([connection1KeepAliveTimer])) - XCTAssertEqual(stateMachine.timerScheduled(connection1KeepAliveTimer, cancelContinuation: connection1KeepAliveTimerCancellationToken), .none) + #expect(stateMachine.timerScheduled(connection1KeepAliveTimer, cancelContinuation: connection1KeepAliveTimerCancellationToken) == .none) let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) let connection2KeepAliveTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 1, usecase: .keepAlive), duration: .seconds(10)) let connection2KeepAliveTimerCancellationToken = MockTimerCancellationToken(connection2KeepAliveTimer) - XCTAssertEqual(createdAction2.request, .none) - XCTAssertEqual(createdAction2.connection, .scheduleTimers([connection2KeepAliveTimer])) - XCTAssertEqual(stateMachine.timerScheduled(connection2KeepAliveTimer, cancelContinuation: connection2KeepAliveTimerCancellationToken), .none) + #expect(createdAction2.request == .none) + #expect(createdAction2.connection == .scheduleTimers([connection2KeepAliveTimer])) + #expect(stateMachine.timerScheduled(connection2KeepAliveTimer, cancelContinuation: connection2KeepAliveTimerCancellationToken) == .none) } } - func testConnectionsNoKeepAliveRun() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testConnectionsNoKeepAliveRun() { var configuration = PoolConfiguration() configuration.minimumConnectionCount = 1 configuration.maximumConnectionSoftLimit = 4 @@ -68,51 +70,52 @@ final class PoolStateMachineTests: XCTestCase { // refill pool to at least one connection let requests = stateMachine.refillConnections() - XCTAssertEqual(requests.count, 1) + #expect(requests.count == 1) let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) - XCTAssertEqual(createdAction1.request, .none) - XCTAssertEqual(createdAction1.connection, .scheduleTimers([])) + #expect(createdAction1.request == .none) + #expect(createdAction1.connection == .scheduleTimers([])) // lease connection 1 - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) - XCTAssertEqual(leaseRequest1.connection, .cancelTimers([])) - XCTAssertEqual(leaseRequest1.request, .leaseConnection(.init(element: request1), connection1)) + #expect(leaseRequest1.connection == .cancelTimers([])) + #expect(leaseRequest1.request == .leaseConnection(.init(element: request1), connection1)) // release connection 1 - XCTAssertEqual(stateMachine.releaseConnection(connection1, streams: 1), .none()) + #expect(stateMachine.releaseConnection(connection1, streams: 1) == .none()) // lease connection 1 - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) - XCTAssertEqual(leaseRequest2.connection, .cancelTimers([])) - XCTAssertEqual(leaseRequest2.request, .leaseConnection(.init(element: request2), connection1)) + #expect(leaseRequest2.connection == .cancelTimers([])) + #expect(leaseRequest2.request == .leaseConnection(.init(element: request2), connection1)) // request connection while none is available - let request3 = MockRequest() + let request3 = MockRequest(connectionType: MockConnection.self) let leaseRequest3 = stateMachine.leaseConnection(request3) - XCTAssertEqual(leaseRequest3.connection, .makeConnection(.init(connectionID: 1), [])) - XCTAssertEqual(leaseRequest3.request, .none) + #expect(leaseRequest3.connection == .makeConnection(.init(connectionID: 1), [])) + #expect(leaseRequest3.request == .none) // make connection 2 and lease immediately let connection2 = MockConnection(id: 1) let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) - XCTAssertEqual(createdAction2.request, .leaseConnection(.init(element: request3), connection2)) - XCTAssertEqual(createdAction2.connection, .none) + #expect(createdAction2.request == .leaseConnection(.init(element: request3), connection2)) + #expect(createdAction2.connection == .none) // release connection 2 let connection2IdleTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 1, usecase: .idleTimeout), duration: configuration.idleTimeoutDuration) let connection2IdleTimerCancellationToken = MockTimerCancellationToken(connection2IdleTimer) - XCTAssertEqual( - stateMachine.releaseConnection(connection2, streams: 1), + #expect( + stateMachine.releaseConnection(connection2, streams: 1) == .init(request: .none, connection: .scheduleTimers([connection2IdleTimer])) ) - XCTAssertEqual(stateMachine.timerScheduled(connection2IdleTimer, cancelContinuation: connection2IdleTimerCancellationToken), .none) - XCTAssertEqual(stateMachine.timerTriggered(connection2IdleTimer), .init(request: .none, connection: .closeConnection(connection2, [connection2IdleTimerCancellationToken]))) + #expect(stateMachine.timerScheduled(connection2IdleTimer, cancelContinuation: connection2IdleTimerCancellationToken) == .none) + #expect(stateMachine.timerTriggered(connection2IdleTimer) == .init(request: .none, connection: .closeConnection(connection2, [connection2IdleTimerCancellationToken]))) } - func testOnlyOverflowConnections() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testOnlyOverflowConnections() { var configuration = PoolConfiguration() configuration.minimumConnectionCount = 0 configuration.maximumConnectionSoftLimit = 0 @@ -128,49 +131,50 @@ final class PoolStateMachineTests: XCTestCase { // don't refill pool let requests = stateMachine.refillConnections() - XCTAssertEqual(requests.count, 0) + #expect(requests.count == 0) // request connection while none exists - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) - XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) - XCTAssertEqual(leaseRequest1.request, .none) + #expect(leaseRequest1.connection == .makeConnection(.init(connectionID: 0), [])) + #expect(leaseRequest1.request == .none) // make connection 1 and lease immediately let connection1 = MockConnection(id: 0) let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) - XCTAssertEqual(createdAction1.request, .leaseConnection(.init(element: request1), connection1)) - XCTAssertEqual(createdAction1.connection, .none) + #expect(createdAction1.request == .leaseConnection(.init(element: request1), connection1)) + #expect(createdAction1.connection == .none) // request connection while none is available - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) - XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) - XCTAssertEqual(leaseRequest2.request, .none) + #expect(leaseRequest2.connection == .makeConnection(.init(connectionID: 1), [])) + #expect(leaseRequest2.request == .none) // release connection 1 should be leased again immediately let releaseRequest1 = stateMachine.releaseConnection(connection1, streams: 1) - XCTAssertEqual(releaseRequest1.request, .leaseConnection(.init(element: request2), connection1)) - XCTAssertEqual(releaseRequest1.connection, .none) + #expect(releaseRequest1.request == .leaseConnection(.init(element: request2), connection1)) + #expect(releaseRequest1.connection == .none) // connection 2 comes up and should be closed right away let connection2 = MockConnection(id: 1) let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) - XCTAssertEqual(createdAction2.request, .none) - XCTAssertEqual(createdAction2.connection, .closeConnection(connection2, [])) - XCTAssertEqual(stateMachine.connectionClosed(connection2), .none()) + #expect(createdAction2.request == .none) + #expect(createdAction2.connection == .closeConnection(connection2, [])) + #expect(stateMachine.connectionClosed(connection2) == .none()) // release connection 1 should be closed as well let releaseRequest2 = stateMachine.releaseConnection(connection1, streams: 1) - XCTAssertEqual(releaseRequest2.request, .none) - XCTAssertEqual(releaseRequest2.connection, .closeConnection(connection1, [])) + #expect(releaseRequest2.request == .none) + #expect(releaseRequest2.connection == .closeConnection(connection1, [])) let shutdownAction = stateMachine.triggerForceShutdown() - XCTAssertEqual(shutdownAction.request, .failRequests(.init(), .poolShutdown)) - XCTAssertEqual(shutdownAction.connection, .shutdown(.init())) + #expect(shutdownAction.request == .failRequests(.init(), .poolShutdown)) + #expect(shutdownAction.connection == .shutdown(.init())) } - func testDemandConnectionIsMadePermanentIfPermanentIsClose() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testDemandConnectionIsMadePermanentIfPermanentIsClose() { var configuration = PoolConfiguration() configuration.minimumConnectionCount = 1 configuration.maximumConnectionSoftLimit = 2 @@ -188,44 +192,45 @@ final class PoolStateMachineTests: XCTestCase { // refill pool to at least one connection let requests = stateMachine.refillConnections() - XCTAssertEqual(requests.count, 1) + #expect(requests.count == 1) let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) - XCTAssertEqual(createdAction1.request, .none) - XCTAssertEqual(createdAction1.connection, .scheduleTimers([])) + #expect(createdAction1.request == .none) + #expect(createdAction1.connection == .scheduleTimers([])) // lease connection 1 - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) - XCTAssertEqual(leaseRequest1.connection, .cancelTimers([])) - XCTAssertEqual(leaseRequest1.request, .leaseConnection(.init(element: request1), connection1)) + #expect(leaseRequest1.connection == .cancelTimers([])) + #expect(leaseRequest1.request == .leaseConnection(.init(element: request1), connection1)) // request connection while none is available - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) - XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) - XCTAssertEqual(leaseRequest2.request, .none) + #expect(leaseRequest2.connection == .makeConnection(.init(connectionID: 1), [])) + #expect(leaseRequest2.request == .none) // make connection 2 and lease immediately let connection2 = MockConnection(id: 1) let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) - XCTAssertEqual(createdAction2.request, .leaseConnection(.init(element: request2), connection2)) - XCTAssertEqual(createdAction2.connection, .none) + #expect(createdAction2.request == .leaseConnection(.init(element: request2), connection2)) + #expect(createdAction2.connection == .none) // release connection 2 let connection2IdleTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 1, usecase: .idleTimeout), duration: configuration.idleTimeoutDuration) let connection2IdleTimerCancellationToken = MockTimerCancellationToken(connection2IdleTimer) - XCTAssertEqual( - stateMachine.releaseConnection(connection2, streams: 1), + #expect( + stateMachine.releaseConnection(connection2, streams: 1) == .init(request: .none, connection: .scheduleTimers([connection2IdleTimer])) ) - XCTAssertEqual(stateMachine.timerScheduled(connection2IdleTimer, cancelContinuation: connection2IdleTimerCancellationToken), .none) + #expect(stateMachine.timerScheduled(connection2IdleTimer, cancelContinuation: connection2IdleTimerCancellationToken) == .none) // connection 1 is dropped - XCTAssertEqual(stateMachine.connectionClosed(connection1), .init(request: .none, connection: .cancelTimers([connection2IdleTimerCancellationToken]))) + #expect(stateMachine.connectionClosed(connection1) == .init(request: .none, connection: .cancelTimers([connection2IdleTimerCancellationToken]))) } - func testReleaseLoosesRaceAgainstClosed() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testReleaseLoosesRaceAgainstClosed() { var configuration = PoolConfiguration() configuration.minimumConnectionCount = 0 configuration.maximumConnectionSoftLimit = 2 @@ -241,32 +246,33 @@ final class PoolStateMachineTests: XCTestCase { // don't refill pool let requests = stateMachine.refillConnections() - XCTAssertEqual(requests.count, 0) + #expect(requests.count == 0) // request connection while none exists - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) - XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) - XCTAssertEqual(leaseRequest1.request, .none) + #expect(leaseRequest1.connection == .makeConnection(.init(connectionID: 0), [])) + #expect(leaseRequest1.request == .none) // make connection 1 and lease immediately let connection1 = MockConnection(id: 0) let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) - XCTAssertEqual(createdAction1.request, .leaseConnection(.init(element: request1), connection1)) - XCTAssertEqual(createdAction1.connection, .none) + #expect(createdAction1.request == .leaseConnection(.init(element: request1), connection1)) + #expect(createdAction1.connection == .none) // connection got closed let closedAction = stateMachine.connectionClosed(connection1) - XCTAssertEqual(closedAction.connection, .none) - XCTAssertEqual(closedAction.request, .none) + #expect(closedAction.connection == .none) + #expect(closedAction.request == .none) // release connection 1 should be leased again immediately let releaseRequest1 = stateMachine.releaseConnection(connection1, streams: 1) - XCTAssertEqual(releaseRequest1.request, .none) - XCTAssertEqual(releaseRequest1.connection, .none) + #expect(releaseRequest1.request == .none) + #expect(releaseRequest1.connection == .none) } - func testKeepAliveOnClosingConnection() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testKeepAliveOnClosingConnection() { var configuration = PoolConfiguration() configuration.minimumConnectionCount = 0 configuration.maximumConnectionSoftLimit = 2 @@ -274,7 +280,6 @@ final class PoolStateMachineTests: XCTestCase { configuration.keepAliveDuration = .seconds(2) configuration.idleTimeoutDuration = .seconds(4) - var stateMachine = TestPoolStateMachine( configuration: configuration, generator: .init(), @@ -283,46 +288,46 @@ final class PoolStateMachineTests: XCTestCase { // don't refill pool let requests = stateMachine.refillConnections() - XCTAssertEqual(requests.count, 0) + #expect(requests.count == 0) // request connection while none exists - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) - XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) - XCTAssertEqual(leaseRequest1.request, .none) + #expect(leaseRequest1.connection == .makeConnection(.init(connectionID: 0), [])) + #expect(leaseRequest1.request == .none) // make connection 1 let connection1 = MockConnection(id: 0) let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) - XCTAssertEqual(createdAction1.request, .leaseConnection(.init(element: request1), connection1)) - XCTAssertEqual(createdAction1.connection, .none) + #expect(createdAction1.request == .leaseConnection(.init(element: request1), connection1)) + #expect(createdAction1.connection == .none) _ = stateMachine.releaseConnection(connection1, streams: 1) // trigger keep alive let keepAliveAction1 = stateMachine.connectionKeepAliveTimerTriggered(connection1.id) - XCTAssertEqual(keepAliveAction1.connection, .runKeepAlive(connection1, nil)) + #expect(keepAliveAction1.connection == .runKeepAlive(connection1, nil)) // fail keep alive and cause closed let keepAliveFailed1 = stateMachine.connectionKeepAliveFailed(connection1.id) - XCTAssertEqual(keepAliveFailed1.connection, .closeConnection(connection1, [])) + #expect(keepAliveFailed1.connection == .closeConnection(connection1, [])) connection1.closeIfClosing() // request connection while none exists anymore - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) - XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) - XCTAssertEqual(leaseRequest2.request, .none) + #expect(leaseRequest2.connection == .makeConnection(.init(connectionID: 1), [])) + #expect(leaseRequest2.request == .none) // make connection 2 let connection2 = MockConnection(id: 1) let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) - XCTAssertEqual(createdAction2.request, .leaseConnection(.init(element: request2), connection2)) - XCTAssertEqual(createdAction2.connection, .none) + #expect(createdAction2.request == .leaseConnection(.init(element: request2), connection2)) + #expect(createdAction2.connection == .none) _ = stateMachine.releaseConnection(connection2, streams: 1) // trigger keep alive while connection is still open let keepAliveAction2 = stateMachine.connectionKeepAliveTimerTriggered(connection2.id) - XCTAssertEqual(keepAliveAction2.connection, .runKeepAlive(connection2, nil)) + #expect(keepAliveAction2.connection == .runKeepAlive(connection2, nil)) // close connection in the middle of keep alive connection2.close() @@ -330,10 +335,11 @@ final class PoolStateMachineTests: XCTestCase { // fail keep alive and cause closed let keepAliveFailed2 = stateMachine.connectionKeepAliveFailed(connection2.id) - XCTAssertEqual(keepAliveFailed2.connection, .closeConnection(connection2, [])) + #expect(keepAliveFailed2.connection == .closeConnection(connection2, [])) } - func testConnectionIsEstablishedAfterFailedKeepAliveIfNotEnoughConnectionsLeft() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testConnectionIsEstablishedAfterFailedKeepAliveIfNotEnoughConnectionsLeft() { var configuration = PoolConfiguration() configuration.minimumConnectionCount = 1 configuration.maximumConnectionSoftLimit = 2 @@ -350,35 +356,38 @@ final class PoolStateMachineTests: XCTestCase { // refill pool let requests = stateMachine.refillConnections() - XCTAssertEqual(requests.count, 1) + #expect(requests.count == 1) // one connection should exist - let request = MockRequest() + let request = MockRequest(connectionType: MockConnection.self) let leaseRequest = stateMachine.leaseConnection(request) - XCTAssertEqual(leaseRequest.connection, .none) - XCTAssertEqual(leaseRequest.request, .none) + #expect(leaseRequest.connection == .none) + #expect(leaseRequest.request == .none) // make connection 1 let connection = MockConnection(id: 0) let createdAction = stateMachine.connectionEstablished(connection, maxStreams: 1) - XCTAssertEqual(createdAction.request, .leaseConnection(.init(element: request), connection)) - XCTAssertEqual(createdAction.connection, .none) + #expect(createdAction.request == .leaseConnection(.init(element: request), connection)) + #expect(createdAction.connection == .none) _ = stateMachine.releaseConnection(connection, streams: 1) // trigger keep alive let keepAliveAction = stateMachine.connectionKeepAliveTimerTriggered(connection.id) - XCTAssertEqual(keepAliveAction.connection, .runKeepAlive(connection, nil)) + #expect(keepAliveAction.connection == .runKeepAlive(connection, nil)) // fail keep alive, cause closed and make new connection let keepAliveFailed = stateMachine.connectionKeepAliveFailed(connection.id) - XCTAssertEqual(keepAliveFailed.connection, .closeConnection(connection, [])) + #expect(keepAliveFailed.connection == .closeConnection(connection, [])) let connectionClosed = stateMachine.connectionClosed(connection) - XCTAssertEqual(connectionClosed.connection, .makeConnection(.init(connectionID: 1), [])) + #expect(connectionClosed.connection == .makeConnection(.init(connectionID: 1), [])) connection.closeIfClosing() let establishAction = stateMachine.connectionEstablished(.init(id: 1), maxStreams: 1) - XCTAssertEqual(establishAction.request, .none) - guard case .scheduleTimers(let timers) = establishAction.connection else { return XCTFail("Unexpected connection action") } - XCTAssertEqual(timers, [.init(.init(timerID: 0, connectionID: 1, usecase: .keepAlive), duration: configuration.keepAliveDuration!)]) + #expect(establishAction.request == .none) + if case .scheduleTimers(let timers) = establishAction.connection { + #expect(timers == [.init(.init(timerID: 0, connectionID: 1, usecase: .keepAlive), duration: configuration.keepAliveDuration!)]) + } else { + Issue.record("Unexpected connection action") + } } } diff --git a/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift index 1a2836b9..9dfac549 100644 --- a/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift +++ b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift @@ -1,72 +1,84 @@ @testable import _ConnectionPoolModule -import XCTest +import Testing -final class TinyFastSequenceTests: XCTestCase { - func testCountIsEmptyAndIterator() async { +@Suite struct TinyFastSequenceTests { + @Test func testCountIsEmptyAndIterator() { var sequence = TinyFastSequence() - XCTAssertEqual(sequence.count, 0) - XCTAssertEqual(sequence.isEmpty, true) - XCTAssertEqual(sequence.first, nil) - XCTAssertEqual(Array(sequence), []) + #expect(sequence.count == 0) + #expect(sequence.isEmpty == true) + #expect(sequence.first == nil) + #expect(Array(sequence) == []) sequence.append(1) - XCTAssertEqual(sequence.count, 1) - XCTAssertEqual(sequence.isEmpty, false) - XCTAssertEqual(sequence.first, 1) - XCTAssertEqual(Array(sequence), [1]) + #expect(sequence.count == 1) + #expect(sequence.isEmpty == false) + #expect(sequence.first == 1) + #expect(Array(sequence) == [1]) sequence.append(2) - XCTAssertEqual(sequence.count, 2) - XCTAssertEqual(sequence.isEmpty, false) - XCTAssertEqual(sequence.first, 1) - XCTAssertEqual(Array(sequence), [1, 2]) + #expect(sequence.count == 2) + #expect(sequence.isEmpty == false) + #expect(sequence.first == 1) + #expect(Array(sequence) == [1, 2]) sequence.append(3) - XCTAssertEqual(sequence.count, 3) - XCTAssertEqual(sequence.isEmpty, false) - XCTAssertEqual(sequence.first, 1) - XCTAssertEqual(Array(sequence), [1, 2, 3]) + #expect(sequence.count == 3) + #expect(sequence.isEmpty == false) + #expect(sequence.first == 1) + #expect(Array(sequence) == [1, 2, 3]) } - func testReserveCapacityIsForwarded() { + @Test func testReserveCapacityIsForwarded() { var emptySequence = TinyFastSequence() emptySequence.reserveCapacity(8) emptySequence.append(1) emptySequence.append(2) emptySequence.append(3) guard case .n(let array) = emptySequence.base else { - return XCTFail("Expected sequence to be backed by an array") + Issue.record("Expected sequence to be backed by an array") + return } - XCTAssertEqual(array.capacity, 8) + #expect(array.capacity >= 8) var oneElemSequence = TinyFastSequence(element: 1) oneElemSequence.reserveCapacity(8) oneElemSequence.append(2) oneElemSequence.append(3) guard case .n(let array) = oneElemSequence.base else { - return XCTFail("Expected sequence to be backed by an array") + Issue.record("Expected sequence to be backed by an array") + return } - XCTAssertEqual(array.capacity, 8) + #expect(array.capacity >= 8) var twoElemSequence = TinyFastSequence([1, 2]) twoElemSequence.reserveCapacity(8) + twoElemSequence.append(3) guard case .n(let array) = twoElemSequence.base else { - return XCTFail("Expected sequence to be backed by an array") + Issue.record("Expected sequence to be backed by an array") + return } - XCTAssertEqual(array.capacity, 8) + #expect(array.capacity >= 8) + + var threeElemSequence = TinyFastSequence([1, 2, 3]) + threeElemSequence.reserveCapacity(8) + guard case .n(let array) = threeElemSequence.base else { + Issue.record("Expected sequence to be backed by an array") + return + } + #expect(array.capacity >= 8) } - func testNewSequenceSlowPath() { + @Test func testNewSequenceSlowPath() { let sequence = TinyFastSequence("AB".utf8) - XCTAssertEqual(Array(sequence), [UInt8(ascii: "A"), UInt8(ascii: "B")]) + #expect(Array(sequence) == [UInt8(ascii: "A"), UInt8(ascii: "B")]) } - func testSingleItem() { + @Test func testSingleItem() { let sequence = TinyFastSequence("A".utf8) - XCTAssertEqual(Array(sequence), [UInt8(ascii: "A")]) + #expect(Array(sequence) == [UInt8(ascii: "A")]) } - func testEmptyCollection() { + @Test func testEmptyCollection() { let sequence = TinyFastSequence("".utf8) - XCTAssertTrue(sequence.isEmpty) - XCTAssertEqual(sequence.count, 0) - XCTAssertEqual(Array(sequence), []) + #expect(sequence.isEmpty == true) + #expect(sequence.count == 0) + #expect(Array(sequence) == []) } } diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 513157fd..b4c8e93f 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -476,6 +476,87 @@ final class AsyncPostgresConnectionTests: XCTestCase { XCTFail("Unexpected error: \(String(describing: error))") } } + + static let preparedStatementWithOptionalTestTable = "AsyncTestPreparedStatementWithOptionalTestTable" + func testPreparedStatementWithOptionalBinding() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + struct InsertPreparedStatement: PostgresPreparedStatement { + static let name = "INSERT-AsyncTestPreparedStatementWithOptionalTestTable" + + static let sql = #"INSERT INTO "\#(AsyncPostgresConnectionTests.preparedStatementWithOptionalTestTable)" (uuid) VALUES ($1);"# + typealias Row = () + + var uuid: UUID? + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.uuid) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + () + } + } + + struct SelectPreparedStatement: PostgresPreparedStatement { + static let name = "SELECT-AsyncTestPreparedStatementWithOptionalTestTable" + + static let sql = #"SELECT id, uuid FROM "\#(AsyncPostgresConnectionTests.preparedStatementWithOptionalTestTable)" WHERE id <= $1;"# + typealias Row = (Int, UUID?) + + var id: Int + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.id) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode((Int, UUID?).self) + } + } + + do { + try await withTestConnection(on: eventLoop) { connection in + try await connection.query(""" + CREATE TABLE IF NOT EXISTS "\(unescaped: Self.preparedStatementWithOptionalTestTable)" ( + id SERIAL PRIMARY KEY, + uuid UUID + ) + """, + logger: .psqlTest + ) + + _ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest) + + let rows = try await connection.execute(SelectPreparedStatement(id: 3), logger: .psqlTest) + var counter = 0 + for try await (id, uuid) in rows { + Logger.psqlTest.info("Received row", metadata: [ + "id": "\(id)", "uuid": "\(String(describing: uuid))" + ]) + counter += 1 + } + + try await connection.query(""" + DROP TABLE "\(unescaped: Self.preparedStatementWithOptionalTestTable)"; + """, + logger: .psqlTest + ) + } + } catch { + XCTFail("Unexpected error: \(String(describing: error))") + } + } } extension XCTestCase { diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 913d91b2..d541899b 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -123,6 +123,25 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(foo, "hello") } + func testQueryNothing() throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + var _result: PostgresQueryResult? + XCTAssertNoThrow(_result = try conn?.query(""" + -- Some comments + """, logger: .psqlTest).wait()) + + let result = try XCTUnwrap(_result) + XCTAssertEqual(result.rows, []) + XCTAssertEqual(result.metadata.command, "") + } + func testDecodeIntegers() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } @@ -359,4 +378,5 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(obj?.bar, 2) } } + } diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index d6d89dc3..9ac92754 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -42,6 +42,145 @@ final class PostgresClientTests: XCTestCase { taskGroup.cancelAll() } } + + func testTransaction() async throws { + var mlogger = Logger(label: "test") + mlogger.logLevel = .debug + let logger = mlogger + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8) + self.addTeardownBlock { + try await eventLoopGroup.shutdownGracefully() + } + + let tableName = "test_client_transactions" + + let clientConfig = PostgresClient.Configuration.makeTestConfiguration() + let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger) + + do { + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() + } + + try await client.query( + """ + CREATE TABLE IF NOT EXISTS "\(unescaped: tableName)" ( + id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + uuid UUID NOT NULL + ); + """, + logger: logger + ) + + let iterations = 1000 + + for _ in 0...psqlArrayType, .int4RangeArray) - XCTAssertEqual(Range.psqlType, .int4Range) - XCTAssertEqual([Range].psqlType, .int4RangeArray) + #expect(Range.psqlArrayType == .int4RangeArray) + #expect(Range.psqlType == .int4Range) + #expect([Range].psqlType == .int4RangeArray) - XCTAssertEqual(ClosedRange.psqlArrayType, .int4RangeArray) - XCTAssertEqual(ClosedRange.psqlType, .int4Range) - XCTAssertEqual([ClosedRange].psqlType, .int4RangeArray) + #expect(ClosedRange.psqlArrayType == .int4RangeArray) + #expect(ClosedRange.psqlType == .int4Range) + #expect([ClosedRange].psqlType == .int4RangeArray) - XCTAssertEqual(Range.psqlArrayType, .int8RangeArray) - XCTAssertEqual(Range.psqlType, .int8Range) - XCTAssertEqual([Range].psqlType, .int8RangeArray) + #expect(Range.psqlArrayType == .int8RangeArray) + #expect(Range.psqlType == .int8Range) + #expect([Range].psqlType == .int8RangeArray) - XCTAssertEqual(ClosedRange.psqlArrayType, .int8RangeArray) - XCTAssertEqual(ClosedRange.psqlType, .int8Range) - XCTAssertEqual([ClosedRange].psqlType, .int8RangeArray) + #expect(ClosedRange.psqlArrayType == .int8RangeArray) + #expect(ClosedRange.psqlType == .int8Range) + #expect([ClosedRange].psqlType == .int8RangeArray) } - func testStringArrayRoundTrip() { + @Test func testStringArrayRoundTrip() { let values = ["foo", "bar", "hello", "world"] var buffer = ByteBuffer() values.encode(into: &buffer, context: .default) var result: [String]? - XCTAssertNoThrow(result = try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) - XCTAssertEqual(values, result) + #expect(throws: Never.self) { + result = try [String](from: &buffer, type: .textArray, format: .binary, context: .default) + } + #expect(values == result) } - func testEmptyStringArrayRoundTrip() { + @Test func testEmptyStringArrayRoundTrip() { let values: [String] = [] var buffer = ByteBuffer() values.encode(into: &buffer, context: .default) var result: [String]? - XCTAssertNoThrow(result = try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) - XCTAssertEqual(values, result) + #expect(throws: Never.self) { + result = try [String](from: &buffer, type: .textArray, format: .binary, context: .default) + } + #expect(values == result) } - func testDecodeFailureIsNotEmptyOutOfScope() { + @Test func testDecodeFailureIsNotEmptyOutOfScope() { var buffer = ByteBuffer() buffer.writeInteger(Int32(2)) // invalid value buffer.writeInteger(Int32(0)) buffer.writeInteger(String.psqlType.rawValue) - XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try [String](from: &buffer, type: .textArray, format: .binary, context: .default) } } - func testDecodeFailureSecondValueIsUnexpected() { + @Test func testDecodeFailureSecondValueIsUnexpected() { var buffer = ByteBuffer() buffer.writeInteger(Int32(0)) // is empty buffer.writeInteger(Int32(1)) // invalid value, must always be 0 buffer.writeInteger(String.psqlType.rawValue) - XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try [String](from: &buffer, type: .textArray, format: .binary, context: .default) } } - func testDecodeFailureTriesDecodeInt8() { + @Test func testDecodeFailureTriesDecodeInt8() { let value: Int64 = 1 << 32 var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) - XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try [String](from: &buffer, type: .textArray, format: .binary, context: .default) } } - func testDecodeFailureInvalidNumberOfArrayElements() { + @Test func testDecodeFailureInvalidNumberOfArrayElements() { var buffer = ByteBuffer() buffer.writeInteger(Int32(1)) // invalid value buffer.writeInteger(Int32(0)) @@ -139,12 +143,12 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int32(-123)) // expected element count buffer.writeInteger(Int32(1)) // dimensions... must be one - XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try [String](from: &buffer, type: .textArray, format: .binary, context: .default) } } - func testDecodeFailureInvalidNumberOfDimensions() { + @Test func testDecodeFailureInvalidNumberOfDimensions() { var buffer = ByteBuffer() buffer.writeInteger(Int32(1)) // invalid value buffer.writeInteger(Int32(0)) @@ -152,12 +156,12 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int32(1)) // expected element count buffer.writeInteger(Int32(2)) // dimensions... must be one - XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try [String](from: &buffer, type: .textArray, format: .binary, context: .default) } } - func testDecodeUnexpectedEnd() { + @Test func testDecodeUnexpectedEnd() { var unexpectedEndInElementLengthBuffer = ByteBuffer() unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // invalid value unexpectedEndInElementLengthBuffer.writeInteger(Int32(0)) @@ -166,8 +170,8 @@ class Array_PSQLCodableTests: XCTestCase { unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // dimensions unexpectedEndInElementLengthBuffer.writeInteger(Int16(1)) // length of element, must be Int32 - XCTAssertThrowsError(try [String](from: &unexpectedEndInElementLengthBuffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try [String](from: &unexpectedEndInElementLengthBuffer, type: .textArray, format: .binary, context: .default) } var unexpectedEndInElementBuffer = ByteBuffer() @@ -179,8 +183,8 @@ class Array_PSQLCodableTests: XCTestCase { unexpectedEndInElementBuffer.writeInteger(Int32(12)) // length of element, must be Int32 unexpectedEndInElementBuffer.writeString("Hello World") // only 11 bytes, 12 needed! - XCTAssertThrowsError(try [String](from: &unexpectedEndInElementBuffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try [String](from: &unexpectedEndInElementBuffer, type: .textArray, format: .binary, context: .default) } } } diff --git a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift index e6e43f0b..d23eff08 100644 --- a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift @@ -1,89 +1,97 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class Bool_PSQLCodableTests: XCTestCase { +@Suite struct Bool_PSQLCodableTests { // MARK: - Binary - func testBinaryTrueRoundTrip() { + @Test func testBinaryTrueRoundTrip() { let value = true var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) - XCTAssertEqual(Bool.psqlType, .bool) - XCTAssertEqual(Bool.psqlFormat, .binary) - XCTAssertEqual(buffer.readableBytes, 1) - XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 1) + #expect(Bool.psqlType == .bool) + #expect(Bool.psqlFormat == .binary) + #expect(buffer.readableBytes == 1) + #expect(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self) == 1) var result: Bool? - XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) - XCTAssertEqual(value, result) + #expect(throws: Never.self) { + result = try Bool(from: &buffer, type: .bool, format: .binary, context: .default) + } + #expect(value == result) } - func testBinaryFalseRoundTrip() { + @Test func testBinaryFalseRoundTrip() { let value = false var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) - XCTAssertEqual(Bool.psqlType, .bool) - XCTAssertEqual(Bool.psqlFormat, .binary) - XCTAssertEqual(buffer.readableBytes, 1) - XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 0) + #expect(Bool.psqlType == .bool) + #expect(Bool.psqlFormat == .binary) + #expect(buffer.readableBytes == 1) + #expect(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self) == 0) var result: Bool? - XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) - XCTAssertEqual(value, result) + #expect(throws: Never.self) { + result = try Bool(from: &buffer, type: .bool, format: .binary, context: .default) + } + #expect(value == result) } - func testBinaryDecodeBoolInvalidLength() { + @Test func testBinaryDecodeBoolInvalidLength() { var buffer = ByteBuffer() buffer.writeInteger(Int64(1)) - XCTAssertThrowsError(try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try Bool(from: &buffer, type: .bool, format: .binary, context: .default) } } - func testBinaryDecodeBoolInvalidValue() { + @Test func testBinaryDecodeBoolInvalidValue() { var buffer = ByteBuffer() buffer.writeInteger(UInt8(13)) - XCTAssertThrowsError(try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try Bool(from: &buffer, type: .bool, format: .binary, context: .default) } } // MARK: - Text - func testTextTrueDecode() { + @Test func testTextTrueDecode() { let value = true var buffer = ByteBuffer() buffer.writeInteger(UInt8(ascii: "t")) var result: Bool? - XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .text, context: .default)) - XCTAssertEqual(value, result) + #expect(throws: Never.self) { + result = try Bool(from: &buffer, type: .bool, format: .text, context: .default) + } + #expect(value == result) } - func testTextFalseDecode() { + @Test func testTextFalseDecode() { let value = false var buffer = ByteBuffer() buffer.writeInteger(UInt8(ascii: "f")) var result: Bool? - XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .text, context: .default)) - XCTAssertEqual(value, result) + #expect(throws: Never.self) { + result = try Bool(from: &buffer, type: .bool, format: .text, context: .default) + } + #expect(value == result) } - func testTextDecodeBoolInvalidValue() { + @Test func testTextDecodeBoolInvalidValue() { var buffer = ByteBuffer() buffer.writeInteger(UInt8(13)) - XCTAssertThrowsError(try Bool(from: &buffer, type: .bool, format: .text, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try Bool(from: &buffer, type: .bool, format: .text, context: .default) } } } diff --git a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift index 9230aee7..77051775 100644 --- a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift @@ -1,34 +1,35 @@ -import XCTest +import struct Foundation.Data +import Testing import NIOCore @testable import PostgresNIO -class Bytes_PSQLCodableTests: XCTestCase { - - func testDataRoundTrip() { +@Suite struct Bytes_PSQLCodableTests { + + @Test func testDataRoundTrip() { let data = Data((0...UInt8.max)) var buffer = ByteBuffer() data.encode(into: &buffer, context: .default) - XCTAssertEqual(ByteBuffer.psqlType, .bytea) - + #expect(ByteBuffer.psqlType == .bytea) + var result: Data? result = Data(from: &buffer, type: .bytea, format: .binary, context: .default) - XCTAssertEqual(data, result) + #expect(data == result) } - func testByteBufferRoundTrip() { + @Test func testByteBufferRoundTrip() { let bytes = ByteBuffer(bytes: (0...UInt8.max)) var buffer = ByteBuffer() bytes.encode(into: &buffer, context: .default) - XCTAssertEqual(ByteBuffer.psqlType, .bytea) - + #expect(ByteBuffer.psqlType == .bytea) + var result: ByteBuffer? result = ByteBuffer(from: &buffer, type: .bytea, format: .binary, context: .default) - XCTAssertEqual(bytes, result) + #expect(bytes == result) } - func testEncodeSequenceWhereElementUInt8() { + @Test func testEncodeSequenceWhereElementUInt8() { struct ByteSequence: Sequence, PostgresEncodable { typealias Element = UInt8 typealias Iterator = Array.Iterator @@ -47,7 +48,7 @@ class Bytes_PSQLCodableTests: XCTestCase { let sequence = ByteSequence() var buffer = ByteBuffer() sequence.encode(into: &buffer, context: .default) - XCTAssertEqual(ByteSequence.psqlType, .bytea) - XCTAssertEqual(buffer.readableBytes, 256) + #expect(ByteSequence.psqlType == .bytea) + #expect(buffer.readableBytes == 256) } } diff --git a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift index 769bde4b..3f406598 100644 --- a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift @@ -14,7 +14,7 @@ class Date_PSQLCodableTests: XCTestCase { var result: Date? XCTAssertNoThrow(result = try Date(from: &buffer, type: .timestamptz, format: .binary, context: .default)) - XCTAssertEqual(value, result) + XCTAssertEqual(value.timeIntervalSince1970, result?.timeIntervalSince1970 ?? 0, accuracy: 0.001) } func testDecodeRandomDate() { diff --git a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift index 6ff35130..c1843c2a 100644 --- a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift @@ -31,18 +31,6 @@ class String_PSQLCodableTests: XCTestCase { } } - func testDecodeFailureFromInvalidType() { - let buffer = ByteBuffer() - let dataTypes: [PostgresDataType] = [.bool, .float4Array, .float8Array] - - for dataType in dataTypes { - var loopBuffer = buffer - XCTAssertThrowsError(try String(from: &loopBuffer, type: dataType, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .typeMismatch) - } - } - } - func testDecodeFromUUID() { let uuid = UUID() var buffer = ByteBuffer() @@ -64,4 +52,15 @@ class String_PSQLCodableTests: XCTestCase { XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } + + func testDecodeFromJSONB() { + let json = #"{"hello": "world"}"# + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(1)) + buffer.writeString(json) + + var decoded: String? + XCTAssertNoThrow(decoded = try String(from: &buffer, type: .jsonb, format: .binary, context: .default)) + XCTAssertEqual(decoded, json) + } } diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift index 9614bf1e..0c6b37ef 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift @@ -28,6 +28,8 @@ struct PSQLBackendMessageEncoder: MessageToByteEncoder { case .commandComplete(let string): self.encode(messageID: message.id, payload: StringPayload(string), into: &buffer) + case .copyInResponse(let copyInResponse): + self.encode(messageID: message.id, payload: copyInResponse, into: &buffer) case .dataRow(let row): self.encode(messageID: message.id, payload: row, into: &buffer) @@ -99,6 +101,8 @@ extension PostgresBackendMessage { return .closeComplete case .commandComplete: return .commandComplete + case .copyInResponse: + return .copyInResponse case .dataRow: return .dataRow case .emptyQueryResponse: @@ -184,6 +188,16 @@ extension PostgresBackendMessage.BackendKeyData: PSQLMessagePayloadEncodable { } } +extension PostgresBackendMessage.CopyInResponse: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeInteger(Int8(self.format.rawValue)) + buffer.writeInteger(Int16(self.columnFormats.count)) + for columnFormat in columnFormats { + buffer.writeInteger(Int16(columnFormat.rawValue)) + } + } +} + extension DataRow: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { buffer.writeInteger(self.columnCount, as: Int16.self) diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index 55ccd0a9..d913da22 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -168,6 +168,18 @@ extension PostgresFrontendMessage { ) ) + case .copyData: + return .copyData(CopyData(data: buffer)) + + case .copyDone: + return .copyDone + + case .copyFail: + guard let message = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + return .copyFail(CopyFail(message: message)) + case .close: preconditionFailure("TODO: Unimplemented") diff --git a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift index 2532959a..5fc8144b 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift @@ -36,6 +36,14 @@ enum PostgresFrontendMessage: Equatable { let secretKey: Int32 } + struct CopyData: Hashable { + var data: ByteBuffer + } + + struct CopyFail: Hashable { + var message: String + } + enum Close: Hashable { case preparedStatement(String) case portal(String) @@ -170,6 +178,9 @@ enum PostgresFrontendMessage: Equatable { case bind(Bind) case cancel(Cancel) + case copyData(CopyData) + case copyDone + case copyFail(CopyFail) case close(Close) case describe(Describe) case execute(Execute) @@ -186,6 +197,9 @@ enum PostgresFrontendMessage: Equatable { enum ID: UInt8, Equatable { case bind + case copyData + case copyDone + case copyFail case close case describe case execute @@ -201,12 +215,18 @@ enum PostgresFrontendMessage: Equatable { switch rawValue { case UInt8(ascii: "B"): self = .bind + case UInt8(ascii: "c"): + self = .copyDone case UInt8(ascii: "C"): self = .close + case UInt8(ascii: "d"): + self = .copyData case UInt8(ascii: "D"): self = .describe case UInt8(ascii: "E"): self = .execute + case UInt8(ascii: "f"): + self = .copyFail case UInt8(ascii: "H"): self = .flush case UInt8(ascii: "P"): @@ -230,6 +250,12 @@ enum PostgresFrontendMessage: Equatable { switch self { case .bind: return UInt8(ascii: "B") + case .copyData: + return UInt8(ascii: "d") + case .copyDone: + return UInt8(ascii: "c") + case .copyFail: + return UInt8(ascii: "f") case .close: return UInt8(ascii: "C") case .describe: @@ -263,6 +289,12 @@ extension PostgresFrontendMessage { return .bind case .cancel: preconditionFailure("Cancel messages don't have an identifier") + case .copyData: + return .copyData + case .copyDone: + return .copyDone + case .copyFail: + return .copyFail case .close: return .close case .describe: diff --git a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift index 06e39aae..3b857157 100644 --- a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift @@ -1,10 +1,10 @@ -import XCTest +import Testing import NIOCore import NIOTestUtils @testable import PostgresNIO -class AuthenticationTests: XCTestCase { - +@Suite struct AuthenticationTests { + func testDecodeAuthentication() { var expected = [PostgresBackendMessage]() var buffer = ByteBuffer() @@ -39,9 +39,11 @@ class AuthenticationTests: XCTestCase { encoder.encode(data: .authentication(.sspi), out: &buffer) expected.append(.authentication(.sspi)) - XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( - inputOutputPairs: [(buffer, expected)], - decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) } - )) + #expect(throws: Never.self) { + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, expected)], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) } + ) + } } } diff --git a/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift b/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift index d41607e3..204b544d 100644 --- a/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift @@ -1,10 +1,10 @@ -import XCTest +import Testing import NIOCore import NIOTestUtils @testable import PostgresNIO -class BackendKeyDataTests: XCTestCase { - func testDecode() { +@Suite struct BackendKeyDataTests { + @Test func testDecode() { let buffer = ByteBuffer.backendMessage(id: .backendKeyData) { buffer in buffer.writeInteger(Int32(1234)) buffer.writeInteger(Int32(4567)) @@ -14,12 +14,15 @@ class BackendKeyDataTests: XCTestCase { (buffer, [PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 4567))]), ] - XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( - inputOutputPairs: expectedInOuts, - decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) + #expect(throws: Never.self) { + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: expectedInOuts, + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) } + ) + } } - func testDecodeInvalidLength() { + @Test func testDecodeInvalidLength() { var buffer = ByteBuffer() buffer.psqlWriteBackendMessageID(.backendKeyData) buffer.writeInteger(Int32(11)) @@ -30,10 +33,11 @@ class BackendKeyDataTests: XCTestCase { (buffer, [PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 4567))]), ] - XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( - inputOutputPairs: expected, - decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { - XCTAssert($0 is PostgresMessageDecodingError) + #expect(throws: PostgresMessageDecodingError.self) { + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: expected, + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) } + ) } } } diff --git a/Tests/PostgresNIOTests/New/Messages/BindTests.swift b/Tests/PostgresNIOTests/New/Messages/BindTests.swift index d5ec5b30..24925fdf 100644 --- a/Tests/PostgresNIOTests/New/Messages/BindTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BindTests.swift @@ -1,10 +1,10 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class BindTests: XCTestCase { - - func testEncodeBind() { +@Suite struct BindTests { + + @Test func testEncodeBind() { var bindings = PostgresBindings() bindings.append("Hello", context: .default) bindings.append("World", context: .default) @@ -14,34 +14,34 @@ class BindTests: XCTestCase { encoder.bind(portalName: "", preparedStatementName: "", bind: bindings) var byteBuffer = encoder.flushBuffer() - XCTAssertEqual(byteBuffer.readableBytes, 37) - XCTAssertEqual(PostgresFrontendMessage.ID.bind.rawValue, byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 36) - XCTAssertEqual("", byteBuffer.readNullTerminatedString()) - XCTAssertEqual("", byteBuffer.readNullTerminatedString()) + #expect(byteBuffer.readableBytes == 37) + #expect(PostgresFrontendMessage.ID.bind.rawValue == byteBuffer.readInteger(as: UInt8.self)) + #expect(byteBuffer.readInteger(as: Int32.self) == 36) + #expect("" == byteBuffer.readNullTerminatedString()) + #expect("" == byteBuffer.readNullTerminatedString()) // the number of parameters - XCTAssertEqual(2, byteBuffer.readInteger(as: Int16.self)) + #expect(2 == byteBuffer.readInteger(as: Int16.self)) // all (two) parameters have the same format (binary) - XCTAssertEqual(1, byteBuffer.readInteger(as: Int16.self)) - XCTAssertEqual(1, byteBuffer.readInteger(as: Int16.self)) - + #expect(1 == byteBuffer.readInteger(as: Int16.self)) + #expect(1 == byteBuffer.readInteger(as: Int16.self)) + // read number of parameters - XCTAssertEqual(2, byteBuffer.readInteger(as: Int16.self)) - + #expect(2 == byteBuffer.readInteger(as: Int16.self)) + // hello length - XCTAssertEqual(5, byteBuffer.readInteger(as: Int32.self)) - XCTAssertEqual("Hello", byteBuffer.readString(length: 5)) - + #expect(5 == byteBuffer.readInteger(as: Int32.self)) + #expect("Hello" == byteBuffer.readString(length: 5)) + // world length - XCTAssertEqual(5, byteBuffer.readInteger(as: Int32.self)) - XCTAssertEqual("World", byteBuffer.readString(length: 5)) - + #expect(5 == byteBuffer.readInteger(as: Int32.self)) + #expect("World" == byteBuffer.readString(length: 5)) + // all response values have the same format: therefore one format byte is next - XCTAssertEqual(1, byteBuffer.readInteger(as: Int16.self)) + #expect(1 == byteBuffer.readInteger(as: Int16.self)) // all response values have the same format (binary) - XCTAssertEqual(1, byteBuffer.readInteger(as: Int16.self)) - + #expect(1 == byteBuffer.readInteger(as: Int16.self)) + // nothing left to read - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == 0) } } diff --git a/Tests/PostgresNIOTests/New/Messages/CancelTests.swift b/Tests/PostgresNIOTests/New/Messages/CancelTests.swift index 5548aae3..c2da01d3 100644 --- a/Tests/PostgresNIOTests/New/Messages/CancelTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CancelTests.swift @@ -1,21 +1,20 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class CancelTests: XCTestCase { - - func testEncodeCancel() { +@Suite struct CancelTests { + @Test func testEncodeCancel() { let processID: Int32 = 1234 let secretKey: Int32 = 4567 var encoder = PostgresFrontendMessageEncoder(buffer: .init()) encoder.cancel(processID: processID, secretKey: secretKey) var byteBuffer = encoder.flushBuffer() - XCTAssertEqual(byteBuffer.readableBytes, 16) - XCTAssertEqual(16, byteBuffer.readInteger(as: Int32.self)) // payload length - XCTAssertEqual(80877102, byteBuffer.readInteger(as: Int32.self)) // cancel request code - XCTAssertEqual(processID, byteBuffer.readInteger(as: Int32.self)) - XCTAssertEqual(secretKey, byteBuffer.readInteger(as: Int32.self)) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == 16) + #expect(16 == byteBuffer.readInteger(as: Int32.self)) // payload length + #expect(80877102 == byteBuffer.readInteger(as: Int32.self)) // cancel request code + #expect(processID == byteBuffer.readInteger(as: Int32.self)) + #expect(secretKey == byteBuffer.readInteger(as: Int32.self)) + #expect(byteBuffer.readableBytes == 0) } } diff --git a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift index a8e1cfeb..9d6f1b37 100644 --- a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift @@ -1,31 +1,31 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class CloseTests: XCTestCase { - func testEncodeClosePortal() { +@Suite struct CloseTests { + @Test func testEncodeClosePortal() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) encoder.closePortal("Hello") var byteBuffer = encoder.flushBuffer() - XCTAssertEqual(byteBuffer.readableBytes, 12) - XCTAssertEqual(PostgresFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual(11, byteBuffer.readInteger(as: Int32.self)) - XCTAssertEqual(UInt8(ascii: "P"), byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual("Hello", byteBuffer.readNullTerminatedString()) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == 12) + #expect(PostgresFrontendMessage.ID.close.rawValue == byteBuffer.readInteger(as: UInt8.self)) + #expect(11 == byteBuffer.readInteger(as: Int32.self)) + #expect(UInt8(ascii: "P") == byteBuffer.readInteger(as: UInt8.self)) + #expect("Hello" == byteBuffer.readNullTerminatedString()) + #expect(byteBuffer.readableBytes == 0) } - func testEncodeCloseUnnamedStatement() { + @Test func testEncodeCloseUnnamedStatement() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) encoder.closePreparedStatement("") var byteBuffer = encoder.flushBuffer() - XCTAssertEqual(byteBuffer.readableBytes, 7) - XCTAssertEqual(PostgresFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual(6, byteBuffer.readInteger(as: Int32.self)) - XCTAssertEqual(UInt8(ascii: "S"), byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual("", byteBuffer.readNullTerminatedString()) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == 7) + #expect(PostgresFrontendMessage.ID.close.rawValue == byteBuffer.readInteger(as: UInt8.self)) + #expect(6 == byteBuffer.readInteger(as: Int32.self)) + #expect(UInt8(ascii: "S") == byteBuffer.readInteger(as: UInt8.self)) + #expect("" == byteBuffer.readNullTerminatedString()) + #expect(byteBuffer.readableBytes == 0) } } diff --git a/Tests/PostgresNIOTests/New/Messages/CopyTests.swift b/Tests/PostgresNIOTests/New/Messages/CopyTests.swift new file mode 100644 index 00000000..01136d05 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/CopyTests.swift @@ -0,0 +1,137 @@ +import Testing +import NIOCore +import NIOTestUtils +@testable import PostgresNIO + +@Suite struct CopyTests { + @Test func testDecodeCopyInResponseMessage() throws { + let expected: [PostgresBackendMessage] = [ + .copyInResponse(.init(format: .textual, columnFormats: [.textual, .textual])), + .copyInResponse(.init(format: .binary, columnFormats: [.binary, .binary])), + .copyInResponse(.init(format: .binary, columnFormats: [.textual, .binary])) + ] + + var buffer = ByteBuffer() + + for message in expected { + guard case .copyInResponse(let message) = message else { + Issue.record("Expected only to get copyInResponse here!") + return + } + buffer.writeBackendMessage(id: .copyInResponse ) { buffer in + buffer.writeInteger(Int8(message.format.rawValue)) + buffer.writeInteger(Int16(message.columnFormats.count)) + for columnFormat in message.columnFormats { + buffer.writeInteger(UInt16(columnFormat.rawValue)) + } + } + } + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, expected)], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + } + + @Test func testDecodeFailureBecauseOfEmptyMessage() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { _ in} + + #expect(throws: PostgresMessageDecodingError.self) { + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + } + } + + + @Test func testDecodeFailureBecauseOfInvalidFormat() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(20)) // Only 0 and 1 are valid formats + } + + #expect(throws: PostgresMessageDecodingError.self) { + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + } + } + + @Test func testDecodeFailureBecauseOfMissingColumnNumber() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(0)) + } + + #expect(throws: PostgresMessageDecodingError.self) { + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + } + } + + @Test func testDecodeFailureBecauseOfMissingColumns() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(0)) + buffer.writeInteger(Int16(20)) // 20 columns promised, none given + } + + #expect(throws: PostgresMessageDecodingError.self) { + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + } + } + + @Test func testDecodeFailureBecauseOfInvalidColumnFormat() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(0)) + buffer.writeInteger(Int16(1)) + buffer.writeInteger(Int8(20)) // Only 0 and 1 are valid formats + } + + #expect(throws: PostgresMessageDecodingError.self) { + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + } + } + + @Test func testEncodeCopyDataHeader() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.copyDataHeader(dataLength: 3) + var byteBuffer = encoder.flushBuffer() + + #expect(byteBuffer.readableBytes == 5) + #expect(PostgresFrontendMessage.ID.copyData.rawValue == byteBuffer.readInteger(as: UInt8.self)) + #expect(byteBuffer.readInteger(as: Int32.self) == 7) + } + + @Test func testEncodeCopyDone() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.copyDone() + var byteBuffer = encoder.flushBuffer() + + #expect(byteBuffer.readableBytes == 5) + #expect(PostgresFrontendMessage.ID.copyDone.rawValue == byteBuffer.readInteger(as: UInt8.self)) + #expect(byteBuffer.readInteger(as: Int32.self) == 4) + } + + @Test func testEncodeCopyFail() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.copyFail(message: "Oh, no :(") + var byteBuffer = encoder.flushBuffer() + + #expect(byteBuffer.readableBytes == 15) + #expect(PostgresFrontendMessage.ID.copyFail.rawValue == byteBuffer.readInteger(as: UInt8.self)) + #expect(byteBuffer.readInteger(as: Int32.self) == 14) + #expect(byteBuffer.readNullTerminatedString() == "Oh, no :(") + } +} diff --git a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift index a90d1e93..f185877a 100644 --- a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift @@ -1,10 +1,11 @@ -import XCTest +import Foundation +import Testing import NIOCore import NIOTestUtils @testable import PostgresNIO -class DataRowTests: XCTestCase { - func testDecode() { +@Suite struct DataRowTests { + @Test func testDecode() { let buffer = ByteBuffer.backendMessage(id: .dataRow) { buffer in // the data row has 3 columns buffer.writeInteger(3, as: Int16.self) @@ -26,23 +27,26 @@ class DataRowTests: XCTestCase { (buffer, [PostgresBackendMessage.dataRow(.init(columnCount: 3, bytes: rowSlice))]), ] - XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( - inputOutputPairs: expectedInOuts, - decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) + #expect(throws: Never.self) { + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: expectedInOuts, + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) } + ) + } } - func testIteratingElements() { + @Test func testIteratingElements() { let dataRow = DataRow.makeTestDataRow(nil, ByteBuffer(), ByteBuffer(repeating: 5, count: 10)) var iterator = dataRow.makeIterator() - XCTAssertEqual(dataRow.count, 3) - XCTAssertEqual(iterator.next(), .some(.none)) - XCTAssertEqual(iterator.next(), ByteBuffer()) - XCTAssertEqual(iterator.next(), ByteBuffer(repeating: 5, count: 10)) - XCTAssertEqual(iterator.next(), .none) + #expect(dataRow.count == 3) + #expect(iterator.next() == .some(.none)) + #expect(iterator.next() == ByteBuffer()) + #expect(iterator.next() == ByteBuffer(repeating: 5, count: 10)) + #expect(iterator.next() == .none) } - func testIndexAfterAndSubscript() { + @Test func testIndexAfterAndSubscript() { let dataRow = DataRow.makeTestDataRow( nil, ByteBuffer(), @@ -51,18 +55,18 @@ class DataRowTests: XCTestCase { ) var index = dataRow.startIndex - XCTAssertEqual(dataRow[index], .none) + #expect(dataRow[index] == .none) index = dataRow.index(after: index) - XCTAssertEqual(dataRow[index], ByteBuffer()) + #expect(dataRow[index] == ByteBuffer()) index = dataRow.index(after: index) - XCTAssertEqual(dataRow[index], ByteBuffer(repeating: 5, count: 10)) + #expect(dataRow[index] == ByteBuffer(repeating: 5, count: 10)) index = dataRow.index(after: index) - XCTAssertEqual(dataRow[index], .none) + #expect(dataRow[index] == .none) index = dataRow.index(after: index) - XCTAssertEqual(index, dataRow.endIndex) + #expect(index == dataRow.endIndex) } - func testIndexComparison() { + @Test func testIndexComparison() { let dataRow = DataRow.makeTestDataRow( nil, ByteBuffer(), @@ -73,18 +77,18 @@ class DataRowTests: XCTestCase { let startIndex = dataRow.startIndex let secondIndex = dataRow.index(after: startIndex) - XCTAssertLessThanOrEqual(startIndex, secondIndex) - XCTAssertLessThan(startIndex, secondIndex) - - XCTAssertGreaterThanOrEqual(secondIndex, startIndex) - XCTAssertGreaterThan(secondIndex, startIndex) - - XCTAssertFalse(secondIndex == startIndex) - XCTAssertEqual(secondIndex, secondIndex) - XCTAssertEqual(startIndex, startIndex) + #expect(startIndex <= secondIndex) + #expect(startIndex < secondIndex) + + #expect(secondIndex >= startIndex) + #expect(secondIndex > startIndex) + + #expect(secondIndex != startIndex) + #expect(secondIndex == secondIndex) + #expect(startIndex == startIndex) } - func testColumnSubscript() { + @Test func testColumnSubscript() { let dataRow = DataRow.makeTestDataRow( nil, ByteBuffer(), @@ -92,14 +96,14 @@ class DataRowTests: XCTestCase { nil ) - XCTAssertEqual(dataRow.count, 4) - XCTAssertEqual(dataRow[column: 0], .none) - XCTAssertEqual(dataRow[column: 1], ByteBuffer()) - XCTAssertEqual(dataRow[column: 2], ByteBuffer(repeating: 5, count: 10)) - XCTAssertEqual(dataRow[column: 3], .none) + #expect(dataRow.count == 4) + #expect(dataRow[column: 0] == .none) + #expect(dataRow[column: 1] == ByteBuffer()) + #expect(dataRow[column: 2] == ByteBuffer(repeating: 5, count: 10)) + #expect(dataRow[column: 3] == .none) } - func testWithContiguousStorageIfAvailable() { + @Test func testWithContiguousStorageIfAvailable() { let dataRow = DataRow.makeTestDataRow( nil, ByteBuffer(), @@ -107,9 +111,9 @@ class DataRowTests: XCTestCase { nil ) - XCTAssertNil(dataRow.withContiguousStorageIfAvailable { _ in - return XCTFail("DataRow does not have a contiguous storage") - }) + #expect(dataRow.withContiguousStorageIfAvailable { _ in + Issue.record("DataRow does not have a contiguous storage") + } == nil) } } diff --git a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift index cb3c745b..42a521aa 100644 --- a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift @@ -1,33 +1,32 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class DescribeTests: XCTestCase { - - func testEncodeDescribePortal() { +@Suite struct DescribeTests { + @Test func testEncodeDescribePortal() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) encoder.describePortal("Hello") var byteBuffer = encoder.flushBuffer() - XCTAssertEqual(byteBuffer.readableBytes, 12) - XCTAssertEqual(PostgresFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual(11, byteBuffer.readInteger(as: Int32.self)) - XCTAssertEqual(UInt8(ascii: "P"), byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual("Hello", byteBuffer.readNullTerminatedString()) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == 12) + #expect(PostgresFrontendMessage.ID.describe.rawValue == byteBuffer.readInteger(as: UInt8.self)) + #expect(11 == byteBuffer.readInteger(as: Int32.self)) + #expect(UInt8(ascii: "P") == byteBuffer.readInteger(as: UInt8.self)) + #expect("Hello" == byteBuffer.readNullTerminatedString()) + #expect(byteBuffer.readableBytes == 0) } - - func testEncodeDescribeUnnamedStatement() { + + @Test func testEncodeDescribeUnnamedStatement() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) encoder.describePreparedStatement("") var byteBuffer = encoder.flushBuffer() - XCTAssertEqual(byteBuffer.readableBytes, 7) - XCTAssertEqual(PostgresFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual(6, byteBuffer.readInteger(as: Int32.self)) - XCTAssertEqual(UInt8(ascii: "S"), byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual("", byteBuffer.readNullTerminatedString()) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == 7) + #expect(PostgresFrontendMessage.ID.describe.rawValue == byteBuffer.readInteger(as: UInt8.self)) + #expect(6 == byteBuffer.readInteger(as: Int32.self)) + #expect(UInt8(ascii: "S") == byteBuffer.readInteger(as: UInt8.self)) + #expect("" == byteBuffer.readNullTerminatedString()) + #expect(byteBuffer.readableBytes == 0) } } diff --git a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift index 834ad0dd..985ab10e 100644 --- a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift @@ -1,18 +1,17 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class ExecuteTests: XCTestCase { - - func testEncodeExecute() { +@Suite struct ExecuteTests { + @Test func testEncodeExecute() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) encoder.execute(portalName: "", maxNumberOfRows: 0) var byteBuffer = encoder.flushBuffer() - XCTAssertEqual(byteBuffer.readableBytes, 10) // 1 (id) + 4 (length) + 1 (empty null terminated string) + 4 (count) - XCTAssertEqual(PostgresFrontendMessage.ID.execute.rawValue, byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual(9, byteBuffer.readInteger(as: Int32.self)) // length - XCTAssertEqual("", byteBuffer.readNullTerminatedString()) - XCTAssertEqual(0, byteBuffer.readInteger(as: Int32.self)) + #expect(byteBuffer.readableBytes == 10) // 1 (id) + 4 (length) + 1 (empty null terminated string) + 4 (count) + #expect(PostgresFrontendMessage.ID.execute.rawValue == byteBuffer.readInteger(as: UInt8.self)) + #expect(9 == byteBuffer.readInteger(as: Int32.self)) // length + #expect("" == byteBuffer.readNullTerminatedString()) + #expect(0 == byteBuffer.readInteger(as: Int32.self)) } } diff --git a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift index 9f81e4e4..e40dbbfe 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift @@ -1,9 +1,9 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class ParseTests: XCTestCase { - func testEncode() { +@Suite struct ParseTests { + @Test func testEncode() { let preparedStatementName = "test" let query = "SELECT version()" let parameters: [PostgresDataType] = [.bool, .int8, .bytea, .varchar, .text, .uuid, .json, .jsonbArray] @@ -22,14 +22,14 @@ class ParseTests: XCTestCase { // + 4 preparedStatement (3 + 1 null terminator) // + 1 query () - XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.parse.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), preparedStatementName) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), query) - XCTAssertEqual(byteBuffer.readInteger(as: UInt16.self), UInt16(parameters.count)) + #expect(byteBuffer.readableBytes == length) + #expect(byteBuffer.readInteger(as: UInt8.self) == PostgresFrontendMessage.ID.parse.rawValue) + #expect(byteBuffer.readInteger(as: Int32.self) == Int32(length - 1)) + #expect(byteBuffer.readNullTerminatedString() == preparedStatementName) + #expect(byteBuffer.readNullTerminatedString() == query) + #expect(byteBuffer.readInteger(as: UInt16.self) == UInt16(parameters.count)) for dataType in parameters { - XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), dataType.rawValue) + #expect(byteBuffer.readInteger(as: UInt32.self) == dataType.rawValue) } } } diff --git a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift index 4a4833d2..cf4ad83f 100644 --- a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift @@ -1,10 +1,9 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class PasswordTests: XCTestCase { - - func testEncodePassword() { +@Suite struct PasswordTests { + @Test func testEncodePassword() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) // md522d085ed8dc3377968dc1c1a40519a2a = "abc123" with salt 1, 2, 3, 4 let password = "md522d085ed8dc3377968dc1c1a40519a2a" @@ -13,9 +12,9 @@ class PasswordTests: XCTestCase { let expectedLength = 41 // 1 (id) + 4 (length) + 35 (string) + 1 (null termination) - XCTAssertEqual(byteBuffer.readableBytes, expectedLength) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.password.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(expectedLength - 1)) // length - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "md522d085ed8dc3377968dc1c1a40519a2a") + #expect(byteBuffer.readableBytes == expectedLength) + #expect(byteBuffer.readInteger(as: UInt8.self) == PostgresFrontendMessage.ID.password.rawValue) + #expect(byteBuffer.readInteger(as: Int32.self) == Int32(expectedLength - 1)) // length + #expect(byteBuffer.readNullTerminatedString() == "md522d085ed8dc3377968dc1c1a40519a2a") } } diff --git a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift index 90aa6b34..7ba31057 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift @@ -1,10 +1,10 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class SASLInitialResponseTests: XCTestCase { +@Suite struct SASLInitialResponseTests { - func testEncode() { + @Test func testEncode() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) let saslMechanism = "hello" let initialData: [UInt8] = [0, 1, 2, 3, 4, 5, 6, 7] @@ -19,16 +19,16 @@ class SASLInitialResponseTests: XCTestCase { // + 4 initialData length // + 8 initialData - XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslInitialResponse.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), saslMechanism) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(initialData.count)) - XCTAssertEqual(byteBuffer.readBytes(length: initialData.count), initialData) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == length) + #expect(byteBuffer.readInteger(as: UInt8.self) == PostgresFrontendMessage.ID.saslInitialResponse.rawValue) + #expect(byteBuffer.readInteger(as: Int32.self) == Int32(length - 1)) + #expect(byteBuffer.readNullTerminatedString() == saslMechanism) + #expect(byteBuffer.readInteger(as: Int32.self) == Int32(initialData.count)) + #expect(byteBuffer.readBytes(length: initialData.count) == initialData) + #expect(byteBuffer.readableBytes == 0) } - func testEncodeWithoutData() { + @Test func testEncodeWithoutData() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) let saslMechanism = "hello" let initialData: [UInt8] = [] @@ -43,12 +43,12 @@ class SASLInitialResponseTests: XCTestCase { // + 4 initialData length // + 0 initialData - XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslInitialResponse.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), saslMechanism) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(-1)) - XCTAssertEqual(byteBuffer.readBytes(length: initialData.count), initialData) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == length) + #expect(byteBuffer.readInteger(as: UInt8.self) == PostgresFrontendMessage.ID.saslInitialResponse.rawValue) + #expect(byteBuffer.readInteger(as: Int32.self) == Int32(length - 1)) + #expect(byteBuffer.readNullTerminatedString() == saslMechanism) + #expect(byteBuffer.readInteger(as: Int32.self) == Int32(-1)) + #expect(byteBuffer.readBytes(length: initialData.count) == initialData) + #expect(byteBuffer.readableBytes == 0) } } diff --git a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift index cdb0f10b..a2a06418 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift @@ -1,10 +1,10 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class SASLResponseTests: XCTestCase { +@Suite struct SASLResponseTests { - func testEncodeWithData() { + @Test func testEncodeWithData() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) let data: [UInt8] = [0, 1, 2, 3, 4, 5, 6, 7] encoder.saslResponse(data) @@ -12,14 +12,14 @@ class SASLResponseTests: XCTestCase { let length: Int = 1 + 4 + (data.count) - XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslResponse.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readBytes(length: data.count), data) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == length) + #expect(byteBuffer.readInteger(as: UInt8.self) == PostgresFrontendMessage.ID.saslResponse.rawValue) + #expect(byteBuffer.readInteger(as: Int32.self) == Int32(length - 1)) + #expect(byteBuffer.readBytes(length: data.count) == data) + #expect(byteBuffer.readableBytes == 0) } - func testEncodeWithoutData() { + @Test func testEncodeWithoutData() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) let data: [UInt8] = [] encoder.saslResponse(data) @@ -27,9 +27,9 @@ class SASLResponseTests: XCTestCase { let length: Int = 1 + 4 - XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslResponse.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == length) + #expect(byteBuffer.readInteger(as: UInt8.self) == PostgresFrontendMessage.ID.saslResponse.rawValue) + #expect(byteBuffer.readInteger(as: Int32.self) == Int32(length - 1)) + #expect(byteBuffer.readableBytes == 0) } } diff --git a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift index 5af3bf34..23d022d9 100644 --- a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift @@ -1,10 +1,9 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class StartupTests: XCTestCase { - - func testStartupMessageWithDatabase() { +@Suite struct StartupTests { + @Test func testStartupMessageWithDatabase() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) var byteBuffer = ByteBuffer() @@ -15,18 +14,18 @@ class StartupTests: XCTestCase { byteBuffer = encoder.flushBuffer() let byteBufferLength = Int32(byteBuffer.readableBytes) - XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) - XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123") - XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) + #expect(byteBufferLength == byteBuffer.readInteger()) + #expect(PostgresFrontendMessage.Startup.versionThree == byteBuffer.readInteger()) + #expect(byteBuffer.readNullTerminatedString() == "user") + #expect(byteBuffer.readNullTerminatedString() == "test") + #expect(byteBuffer.readNullTerminatedString() == "database") + #expect(byteBuffer.readNullTerminatedString() == "abc123") + #expect(byteBuffer.readInteger() == UInt8(0)) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == 0) } - func testStartupMessageWithoutDatabase() { + @Test func testStartupMessageWithoutDatabase() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) var byteBuffer = ByteBuffer() @@ -36,16 +35,16 @@ class StartupTests: XCTestCase { byteBuffer = encoder.flushBuffer() let byteBufferLength = Int32(byteBuffer.readableBytes) - XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) - XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") - XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) + #expect(byteBufferLength == byteBuffer.readInteger()) + #expect(PostgresFrontendMessage.Startup.versionThree == byteBuffer.readInteger()) + #expect(byteBuffer.readNullTerminatedString() == "user") + #expect(byteBuffer.readNullTerminatedString() == "test") + #expect(byteBuffer.readInteger() == UInt8(0)) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == 0) } - func testStartupMessageWithAdditionalOptions() { + @Test func testStartupMessageWithAdditionalOptions() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) var byteBuffer = ByteBuffer() @@ -56,17 +55,17 @@ class StartupTests: XCTestCase { byteBuffer = encoder.flushBuffer() let byteBufferLength = Int32(byteBuffer.readableBytes) - XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) - XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "some") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "options") - XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) - - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBufferLength == byteBuffer.readInteger()) + #expect(PostgresFrontendMessage.Startup.versionThree == byteBuffer.readInteger()) + #expect(byteBuffer.readNullTerminatedString() == "user") + #expect(byteBuffer.readNullTerminatedString() == "test") + #expect(byteBuffer.readNullTerminatedString() == "database") + #expect(byteBuffer.readNullTerminatedString() == "abc123") + #expect(byteBuffer.readNullTerminatedString() == "some") + #expect(byteBuffer.readNullTerminatedString() == "options") + #expect(byteBuffer.readInteger() == UInt8(0)) + + #expect(byteBuffer.readableBytes == 0) } } diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift index 9a1e9e41..65ca26c3 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -12,7 +12,7 @@ final class PSQLRowStreamTests: XCTestCase { func testEmptyStream() { let stream = PSQLRowStream( - source: .noRows(.success("INSERT 0 1")), + source: .noRows(.success(.tag("INSERT 0 1"))), eventLoop: self.eventLoop, logger: self.logger ) diff --git a/Tests/PostgresNIOTests/New/PostgresCellTests.swift b/Tests/PostgresNIOTests/New/PostgresCellTests.swift index 6458d063..4ed64b3f 100644 --- a/Tests/PostgresNIOTests/New/PostgresCellTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresCellTests.swift @@ -1,9 +1,9 @@ @testable import PostgresNIO -import XCTest +import Testing import NIOCore -final class PostgresCellTests: XCTestCase { - func testDecodingANonOptionalString() { +@Suite struct PostgresCellTests { + @Test func testDecodingANonOptionalString() { let cell = PostgresCell( bytes: ByteBuffer(string: "Hello world"), dataType: .text, @@ -13,11 +13,13 @@ final class PostgresCellTests: XCTestCase { ) var result: String? - XCTAssertNoThrow(result = try cell.decode(String.self, context: .default)) - XCTAssertEqual(result, "Hello world") + #expect(throws: Never.self) { + result = try cell.decode(String.self, context: .default) + } + #expect(result == "Hello world") } - func testDecodingAnOptionalString() { + @Test func testDecodingAnOptionalString() { let cell = PostgresCell( bytes: nil, dataType: .text, @@ -27,11 +29,13 @@ final class PostgresCellTests: XCTestCase { ) var result: String? = "test" - XCTAssertNoThrow(result = try cell.decode(String?.self, context: .default)) - XCTAssertNil(result) + #expect(throws: Never.self) { + result = try cell.decode(String?.self, context: .default) + } + #expect(result == nil) } - func testDecodingFailure() { + @Test func testDecodingFailure() { let cell = PostgresCell( bytes: ByteBuffer(string: "Hello world"), dataType: .text, @@ -40,19 +44,44 @@ final class PostgresCellTests: XCTestCase { columnIndex: 1 ) - XCTAssertThrowsError(try cell.decode(Int?.self, context: .default)) { - guard let error = $0 as? PostgresDecodingError else { - return XCTFail("Unexpected error") + #if compiler(>=6.1) + let error = #expect(throws: PostgresDecodingError.self) { + try cell.decode(Int?.self, context: .default) + } + guard let error else { + Issue.record("Expected error at this point") + return + } + + #expect(error.file == #fileID) + #expect(error.line == #line - 9) + #expect(error.code == .typeMismatch) + #expect(error.columnName == "hello") + #expect(error.columnIndex == 1) + let correctType = error.targetType == Int?.self + #expect(correctType) + #expect(error.postgresType == .text) + #expect(error.postgresFormat == .binary) + #else + do { + _ = try cell.decode(Int?.self, context: .default) + Issue.record("Expected to throw") + } catch { + guard let error = error as? PostgresDecodingError else { + Issue.record("Expected error at this point") + return } - XCTAssertEqual(error.file, #fileID) - XCTAssertEqual(error.line, #line - 6) - XCTAssertEqual(error.code, .typeMismatch) - XCTAssertEqual(error.columnName, "hello") - XCTAssertEqual(error.columnIndex, 1) - XCTAssert(error.targetType == Int?.self) - XCTAssertEqual(error.postgresType, .text) - XCTAssertEqual(error.postgresFormat, .binary) + #expect(error.file == #fileID) + #expect(error.line == #line - 9) + #expect(error.code == .typeMismatch) + #expect(error.columnName == "hello") + #expect(error.columnIndex == 1) + let correctType = error.targetType == Int?.self + #expect(correctType) + #expect(error.postgresType == .text) + #expect(error.postgresFormat == .binary) } + #endif } } diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index dfdcc53e..206f38a3 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -48,7 +48,7 @@ class PostgresChannelHandlerTests: XCTestCase { var config = self.testConnectionConfiguration() XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) var addSSLCallbackIsHit = false - let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel in + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in addSSLCallbackIsHit = true } let embedded = EmbeddedChannel(handlers: [ @@ -84,7 +84,7 @@ class PostgresChannelHandlerTests: XCTestCase { var config = self.testConnectionConfiguration() XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) var addSSLCallbackIsHit = false - let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel in + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in addSSLCallbackIsHit = true } let eventHandler = TestEventHandler() @@ -114,7 +114,7 @@ class PostgresChannelHandlerTests: XCTestCase { func testSSLUnsupportedClosesConnection() throws { let config = self.testConnectionConfiguration(tls: .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) - let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel in + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in XCTFail("This callback should never be exectuded") throw PSQLError.sslUnsupported } @@ -124,7 +124,7 @@ class PostgresChannelHandlerTests: XCTestCase { handler ], loop: self.eventLoop) let eventHandler = TestEventHandler() - try embedded.pipeline.addHandler(eventHandler, position: .last).wait() + try embedded.pipeline.syncOperations.addHandler(eventHandler, position: .last) embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil) XCTAssertTrue(embedded.isActive) diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 5c7d4c83..b4658079 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -1,27 +1,29 @@ import NIOCore import NIOPosix import NIOEmbedded -import XCTest +import Testing import Logging @testable import PostgresNIO -class PostgresConnectionTests: XCTestCase { +@Suite struct PostgresConnectionTests { let logger = Logger(label: "PostgresConnectionTests") - func testConnectionFailure() { + @Test func testConnectionFailure() { // We start a local server and close it immediately to ensure that the port // number we try to connect to is not used by any other process. - let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } - + let eventLoopGroup = NIOSingletons.posixEventLoopGroup + var tempChannel: Channel? - XCTAssertNoThrow(tempChannel = try ServerBootstrap(group: eventLoopGroup) - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)).wait()) + #expect(throws: Never.self) { + tempChannel = try ServerBootstrap(group: eventLoopGroup) + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)).wait() + } let maybePort = tempChannel?.localAddress?.port - XCTAssertNoThrow(try tempChannel?.close().wait()) + #expect(throws: Never.self) { try tempChannel?.close().wait() } guard let port = maybePort else { - return XCTFail("Could not get port number from temp started server") + Issue.record("Could not get port number from temp started server") + return } let config = PostgresConnection.Configuration( @@ -33,17 +35,19 @@ class PostgresConnectionTests: XCTestCase { var logger = Logger.psqlTest logger.logLevel = .trace - XCTAssertThrowsError(try PostgresConnection.connect(on: eventLoopGroup.next(), configuration: config, id: 1, logger: logger).wait()) { - XCTAssertTrue($0 is PSQLError) + #expect(throws: PSQLError.self) { + try PostgresConnection + .connect(on: eventLoopGroup.next(), configuration: config, id: 1, logger: logger) + .wait() } } - func testOptionsAreSentOnTheWire() async throws { + @Test func testOptionsAreSentOnTheWire() async throws { let eventLoop = NIOAsyncTestingEventLoop() - let channel = await NIOAsyncTestingChannel(handlers: [ - ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), - ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), - ], loop: eventLoop) + let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in + try channel.pipeline.syncOperations.addHandlers(ReverseByteToMessageHandler(PSQLFrontendMessageDecoder())) + try channel.pipeline.syncOperations.addHandlers(ReverseMessageToByteHandler(PSQLBackendMessageEncoder())) + } try await channel.connect(to: .makeAddressResolvingHost("localhost", port: 5432)) let configuration = { @@ -71,7 +75,7 @@ class PostgresConnectionTests: XCTestCase { async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: .psqlTest) let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) - XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: configuration.options.additionalStartupParameters, replication: .false)))) + #expect(message == .startup(.versionThree(parameters: .init(user: "username", database: "database", options: configuration.options.additionalStartupParameters, replication: .false)))) try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -80,326 +84,275 @@ class PostgresConnectionTests: XCTestCase { try await connection.close() } - func testSimpleListen() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { - let events = try await connection.listen("foo") - for try await event in events { - XCTAssertEqual(event.payload, "wooohooo") - break + @Test func testSimpleListen() async throws { + try await self.withAsyncTestingChannel { connection, channel in + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + let events = try await connection.listen("foo") + for try await event in events { + #expect(event.payload == "wooohooo") + break + } } - } - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - - try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) - - let unlistenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) - - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - - switch await taskGroup.nextResult()! { - case .success: - break - case .failure(let failure): - XCTFail("Unexpected error: \(failure)") - } - } - } + let listenMessage = try await channel.waitForUnpreparedRequest() + #expect(listenMessage.parse.query == #"LISTEN "foo";"#) - func testSimpleListenDoesNotUnlistenIfThereIsAnotherSubscriber() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - try await withThrowingTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { - let events = try await connection.listen("foo") - for try await event in events { - XCTAssertEqual(event.payload, "wooohooo") - break - } - } + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) - taskGroup.addTask { - let events = try await connection.listen("foo") - var counter = 0 - loop: for try await event in events { - defer { counter += 1 } - switch counter { - case 0: - XCTAssertEqual(event.payload, "wooohooo") - case 1: - XCTAssertEqual(event.payload, "wooohooo2") - break loop - default: - XCTFail("Unexpected message: \(event)") - } - } - } + let unlistenMessage = try await channel.waitForUnpreparedRequest() + #expect(unlistenMessage.parse.query == #"UNLISTEN "foo";"#) - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - - try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) - try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo2"))) - - let unlistenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) - - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - - switch await taskGroup.nextResult()! { - case .success: - break - case .failure(let failure): - XCTFail("Unexpected error: \(failure)") - } - } - } + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - func testSimpleListenConnectionDrops() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup in - taskGroup.addTask { - let events = try await connection.listen("foo") - var iterator = events.makeAsyncIterator() - let first = try await iterator.next() - XCTAssertEqual(first?.payload, "wooohooo") - do { - _ = try await iterator.next() - XCTFail("Did not expect to not throw") - } catch { - logger.error("error", metadata: ["error": "\(error)"]) + switch await taskGroup.nextResult()! { + case .success: + break + case .failure(let failure): + Issue.record("Unexpected error: \(failure)") } } - - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - - try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) - struct MyWeirdError: Error {} - channel.pipeline.fireErrorCaught(MyWeirdError()) - - switch await taskGroup.nextResult()! { - case .success: - break - case .failure(let failure): - XCTFail("Unexpected error: \(failure)") - } } } - func testSimpleListenFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + @Test func testSimpleListenDoesNotUnlistenIfThereIsAnotherSubscriber() async throws { + try await self.withAsyncTestingChannel { connection, channel in - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - do { - _ = try await connection.listen("test_channel") - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + let events = try await connection.listen("foo") + for try await event in events { + #expect(event.payload == "wooohooo") + break + } + } - func testSimpleListenFailsIfConnectionIsClosedWhileListening() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { - let events = try await connection.listen("foo") - var iterator = events.makeAsyncIterator() - let first = try await iterator.next() - XCTAssertEqual(first?.payload, "wooohooo") - do { - _ = try await iterator.next() - XCTFail("Did not expect to not throw") - } catch let error as PSQLError { - XCTAssertEqual(error.code, .clientClosedConnection) + taskGroup.addTask { + let events = try await connection.listen("foo") + var counter = 0 + loop: for try await event in events { + defer { counter += 1 } + switch counter { + case 0: + #expect(event.payload == "wooohooo") + case 1: + #expect(event.payload == "wooohooo2") + break loop + default: + Issue.record("Unexpected message: \(event)") + } + } } - } - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) + let listenMessage = try await channel.waitForUnpreparedRequest() + #expect(listenMessage.parse.query == #"LISTEN "foo";"#) - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo2"))) - try await connection.close() + let unlistenMessage = try await channel.waitForUnpreparedRequest() + #expect(unlistenMessage.parse.query == #"UNLISTEN "foo";"#) - XCTAssertEqual(channel.isActive, false) + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - switch await taskGroup.nextResult()! { - case .success: - break - case .failure(let failure): - XCTFail("Unexpected error: \(failure)") + switch await taskGroup.nextResult()! { + case .success: + break + case .failure(let failure): + Issue.record("Unexpected error: \(failure)") + } } } } - func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in - for _ in 1...2 { + @Test func testSimpleListenConnectionDrops() async throws { + try await self.withAsyncTestingChannel { connection, channel in + + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup in taskGroup.addTask { - let rows = try await connection.query("SELECT 1;", logger: logger) - var iterator = rows.decode(Int.self).makeAsyncIterator() + let events = try await connection.listen("foo") + var iterator = events.makeAsyncIterator() let first = try await iterator.next() - XCTAssertEqual(first, 1) - let second = try await iterator.next() - XCTAssertNil(second) + #expect(first?.payload == "wooohooo") + do { + _ = try await iterator.next() + Issue.record("Did not expect to not throw") + } catch { + logger.error("error", metadata: ["error": "\(error)"]) + } } - } - for i in 0...1 { let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") - - if i == 0 { - taskGroup.addTask { - try await connection.closeGracefully() - } - } + #expect(listenMessage.parse.query == #"LISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - let intDescription = RowDescription.Column( - name: "", - tableOID: 0, - columnAttributeNumber: 0, - dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary - ) - try await channel.writeInbound(PostgresBackendMessage.rowDescription(.init(columns: [intDescription]))) - try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.noData) try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.testingEventLoop.executeInContext { channel.read() } - try await channel.writeInbound(PostgresBackendMessage.dataRow([Int(1)])) - try await channel.testingEventLoop.executeInContext { channel.read() } - try await channel.writeInbound(PostgresBackendMessage.commandComplete("SELECT 1 1")) - try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - } - let terminate = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) - XCTAssertEqual(terminate, .terminate) - try await channel.closeFuture.get() - XCTAssertEqual(channel.isActive, false) + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) + struct MyWeirdError: Error {} + channel.pipeline.fireErrorCaught(MyWeirdError()) - while let taskResult = await taskGroup.nextResult() { - switch taskResult { + switch await taskGroup.nextResult()! { case .success: break case .failure(let failure): - XCTFail("Unexpected error: \(failure)") + Issue.record("Unexpected error: \(failure)") } } } } - func testCloseClosesImmediatly() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + @Test func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws { + try await self.withAsyncTestingChannel { connection, channel in + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in + for _ in 1...2 { + taskGroup.addTask { + let rows = try await connection.query("SELECT 1;", logger: logger) + var iterator = rows.decode(Int.self).makeAsyncIterator() + let first = try await iterator.next() + #expect(first == 1) + let second = try await iterator.next() + #expect(second == nil) + } + } - try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in - for _ in 1...2 { - taskGroup.addTask { - try await connection.query("SELECT 1;", logger: logger) + for i in 0...1 { + let listenMessage = try await channel.waitForUnpreparedRequest() + #expect(listenMessage.parse.query == "SELECT 1;") + + if i == 0 { + taskGroup.addTask { + try await connection.closeGracefully() + } + } + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + let intDescription = RowDescription.Column( + name: "", + tableOID: 0, + columnAttributeNumber: 0, + dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary + ) + try await channel.writeInbound(PostgresBackendMessage.rowDescription(.init(columns: [intDescription]))) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.dataRow([Int(1)])) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.commandComplete("SELECT 1 1")) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + } + + let terminate = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + #expect(terminate == .terminate) + try await channel.closeFuture.get() + #expect(!channel.isActive) + + while let taskResult = await taskGroup.nextResult() { + switch taskResult { + case .success: + break + case .failure(let failure): + Issue.record("Unexpected error: \(failure)") + } } } + } + } - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") + @Test func testCloseClosesImmediatly() async throws { + try await self.withAsyncTestingChannel { connection, channel in - async let close: () = connection.close() + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in + for _ in 1...2 { + taskGroup.addTask { + try await connection.query("SELECT 1;", logger: logger) + } + } - try await channel.closeFuture.get() - XCTAssertEqual(channel.isActive, false) + let listenMessage = try await channel.waitForUnpreparedRequest() + #expect(listenMessage.parse.query == "SELECT 1;") - try await close + async let close: () = connection.close() - while let taskResult = await taskGroup.nextResult() { - switch taskResult { - case .success: - XCTFail("Expected queries to fail") - case .failure(let failure): - guard let error = failure as? PSQLError else { - return XCTFail("Unexpected error type: \(failure)") + try await channel.closeFuture.get() + #expect(!channel.isActive) + + try await close + + while let taskResult = await taskGroup.nextResult() { + switch taskResult { + case .success: + Issue.record("Expected queries to fail") + case .failure(let failure): + guard let error = failure as? PSQLError else { + Issue.record("Unexpected error type: \(failure)") + return + } + #expect(error.code == .clientClosedConnection) } - XCTAssertEqual(error.code, .clientClosedConnection) } } } } - func testIfServerJustClosesTheErrorReflectsThat() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - let logger = self.logger + @Test func testIfServerJustClosesTheErrorReflectsThat() async throws { + try await self.withAsyncTestingChannel { connection, channel in + let logger = self.logger - async let response = try await connection.query("SELECT 1;", logger: logger) + async let response = try await connection.query("SELECT 1;", logger: logger) - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") + let listenMessage = try await channel.waitForUnpreparedRequest() + #expect(listenMessage.parse.query == "SELECT 1;") - try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelInactive() } - try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelUnregistered() } + try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelInactive() } + try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelUnregistered() } - do { - _ = try await response - XCTFail("Expected to throw") - } catch { - XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection) - } + do { + _ = try await response + Issue.record("Expected to throw") + } catch { + #expect((error as? PSQLError)?.code == .serverClosedConnection) + } - // retry on same connection + // retry on same connection - do { - _ = try await connection.query("SELECT 1;", logger: self.logger) - XCTFail("Expected to throw") - } catch { - XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection) + do { + _ = try await connection.query("SELECT 1;", logger: self.logger) + Issue.record("Expected to throw") + } catch { + #expect((error as? PSQLError)?.code == .serverClosedConnection) + } } } @@ -420,399 +373,292 @@ class PostgresConnectionTests: XCTestCase { } } - func testPreparedStatement() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "active") - let result = try await connection.execute(preparedStatement, logger: .psqlTest) - var rows = 0 - for try await database in result { - rows += 1 - XCTAssertEqual("test_database", database) - } - XCTAssertEqual(rows, 1) - } - - let prepareRequest = try await channel.waitForPrepareRequest() - XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.first, .text) - guard case .preparedStatement(let name) = prepareRequest.describe else { - fatalError("Describe should contain a prepared statement") - } - XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) - - try await channel.sendPrepareResponse( - parameterDescription: .init(dataTypes: [ - PostgresDataType.text - ]), - rowDescription: .init(columns: [ - .init( - name: "datname", - tableOID: 12222, - columnAttributeNumber: 2, - dataType: .name, - dataTypeSize: 64, - dataTypeModifier: -1, - format: .text - ) - ]) - ) - - let preparedRequest = try await channel.waitForPreparedRequest() - XCTAssertEqual(preparedRequest.bind.preparedStatementName, String(reflecting: TestPrepareStatement.self)) - XCTAssertEqual(preparedRequest.bind.parameters.count, 1) - XCTAssertEqual(preparedRequest.bind.resultColumnFormats, [.binary]) + @Test func testPreparedStatement() async throws { + try await self.withAsyncTestingChannel { connection, channel in - try await channel.sendPreparedResponse( - dataRows: [ - ["test_database"] - ], - commandTag: TestPrepareStatement.sql - ) - } - } - - func testWeDontCrashOnUnexpectedChannelEvents() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - enum MyEvent { - case pleaseDontCrash - } - channel.pipeline.fireUserInboundEventTriggered(MyEvent.pleaseDontCrash) - try await connection.close() - } - - func testSerialExecutionOfSamePreparedStatement() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in - // Send the same prepared statement twice, but with different parameters. - // Send one first and wait to send the other request until preparation is complete - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "active") - let result = try await connection.execute(preparedStatement, logger: .psqlTest) - var rows = 0 - for try await database in result { - rows += 1 - XCTAssertEqual("test_database", database) + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + #expect("test_database" == database) + } + #expect(rows == 1) } - XCTAssertEqual(rows, 1) - } - let prepareRequest = try await channel.waitForPrepareRequest() - XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.first, .text) - guard case .preparedStatement(let name) = prepareRequest.describe else { - fatalError("Describe should contain a prepared statement") - } - XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) - - try await channel.sendPrepareResponse( - parameterDescription: .init(dataTypes: [ - PostgresDataType.text - ]), - rowDescription: .init(columns: [ - .init( - name: "datname", - tableOID: 12222, - columnAttributeNumber: 2, - dataType: .name, - dataTypeSize: 64, - dataTypeModifier: -1, - format: .text - ) - ]) - ) - - let preparedRequest1 = try await channel.waitForPreparedRequest() - var buffer = preparedRequest1.bind.parameters[0]! - let parameter1 = buffer.readString(length: buffer.readableBytes)! - XCTAssertEqual(parameter1, "active") - try await channel.sendPreparedResponse( - dataRows: [ - ["test_database"] - ], - commandTag: TestPrepareStatement.sql - ) - - // Now that the statement has been prepared and executed, send another request that will only get executed - // without preparation - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "idle") - let result = try await connection.execute(preparedStatement, logger: .psqlTest) - var rows = 0 - for try await database in result { - rows += 1 - XCTAssertEqual("test_database", database) + let prepareRequest = try await channel.waitForPrepareRequest() + #expect(prepareRequest.parse.query == "SELECT datname FROM pg_stat_activity WHERE state = $1") + #expect(prepareRequest.parse.parameters.first == .text) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") } - XCTAssertEqual(rows, 1) - } - - let preparedRequest2 = try await channel.waitForPreparedRequest() - buffer = preparedRequest2.bind.parameters[0]! - let parameter2 = buffer.readString(length: buffer.readableBytes)! - XCTAssertEqual(parameter2, "idle") - try await channel.sendPreparedResponse( - dataRows: [ - ["test_database"] - ], - commandTag: TestPrepareStatement.sql - ) - // Ensure we received and responded to both the requests - let parameters = [parameter1, parameter2] - XCTAssert(parameters.contains("active")) - XCTAssert(parameters.contains("idle")) - } - } + #expect(name == String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) - func testStatementPreparationOnlyHappensOnceWithConcurrentRequests() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in - // Send the same prepared statement twice, but with different parameters. - // Let them race to tests that requests and responses aren't mixed up - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "active") - let result = try await connection.execute(preparedStatement, logger: .psqlTest) - var rows = 0 - for try await database in result { - rows += 1 - XCTAssertEqual("test_database_active", database) - } - XCTAssertEqual(rows, 1) - } - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "idle") - let result = try await connection.execute(preparedStatement, logger: .psqlTest) - var rows = 0 - for try await database in result { - rows += 1 - XCTAssertEqual("test_database_idle", database) - } - XCTAssertEqual(rows, 1) - } + let preparedRequest = try await channel.waitForPreparedRequest() + #expect(preparedRequest.bind.preparedStatementName == String(reflecting: TestPrepareStatement.self)) + #expect(preparedRequest.bind.parameters.count == 1) + #expect(preparedRequest.bind.resultColumnFormats == [.binary]) - // The channel deduplicates prepare requests, we're going to see only one of them - let prepareRequest = try await channel.waitForPrepareRequest() - XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.first, .text) - guard case .preparedStatement(let name) = prepareRequest.describe else { - fatalError("Describe should contain a prepared statement") + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) } - XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) - - try await channel.sendPrepareResponse( - parameterDescription: .init(dataTypes: [ - PostgresDataType.text - ]), - rowDescription: .init(columns: [ - .init( - name: "datname", - tableOID: 12222, - columnAttributeNumber: 2, - dataType: .name, - dataTypeSize: 64, - dataTypeModifier: -1, - format: .text - ) - ]) - ) - - // Now both the tasks have their statements prepared. - // We should see both of their execute requests coming in, the order is nondeterministic - let preparedRequest1 = try await channel.waitForPreparedRequest() - var buffer = preparedRequest1.bind.parameters[0]! - let parameter1 = buffer.readString(length: buffer.readableBytes)! - try await channel.sendPreparedResponse( - dataRows: [ - ["test_database_\(parameter1)"] - ], - commandTag: TestPrepareStatement.sql - ) - let preparedRequest2 = try await channel.waitForPreparedRequest() - buffer = preparedRequest2.bind.parameters[0]! - let parameter2 = buffer.readString(length: buffer.readableBytes)! - try await channel.sendPreparedResponse( - dataRows: [ - ["test_database_\(parameter2)"] - ], - commandTag: TestPrepareStatement.sql - ) - // Ensure we received and responded to both the requests - let parameters = [parameter1, parameter2] - XCTAssert(parameters.contains("active")) - XCTAssert(parameters.contains("idle")) } } - func testStatementPreparationFailure() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in - // Send the same prepared statement twice, but with different parameters. - // Send one first and wait to send the other request until preparation is complete - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "active") - do { - _ = try await connection.execute(preparedStatement, logger: .psqlTest) - XCTFail("Was supposed to fail") - } catch { - XCTAssert(error is PSQLError) - } - } + @Test func testWeDontCrashOnUnexpectedChannelEvents() async throws { + try await self.withAsyncTestingChannel { connection, channel in - let prepareRequest = try await channel.waitForPrepareRequest() - XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.first, .text) - guard case .preparedStatement(let name) = prepareRequest.describe else { - fatalError("Describe should contain a prepared statement") - } - XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) - - // Respond with an error taking care to return a SQLSTATE that isn't - // going to get the connection closed. - try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ - .sqlState : "26000" // invalid_sql_statement_name - ]))) - try await channel.testingEventLoop.executeInContext { channel.read() } - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - try await channel.testingEventLoop.executeInContext { channel.read() } - - - // Send another requests with the same prepared statement, which should fail straight - // away without any interaction with the server - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "idle") - do { - _ = try await connection.execute(preparedStatement, logger: .psqlTest) - XCTFail("Was supposed to fail") - } catch { - XCTAssert(error is PSQLError) - } + enum MyEvent { + case pleaseDontCrash } + channel.pipeline.fireUserInboundEventTriggered(MyEvent.pleaseDontCrash) + try await connection.close() } } - func testQueryFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - do { - _ = try await connection.query("SELECT version;", logger: self.logger) - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - - func testPrepareStatementFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() + @Test func testSerialExecutionOfSamePreparedStatement() async throws { + try await self.withAsyncTestingChannel { connection, channel in - XCTAssertEqual(channel.isActive, false) - - do { - _ = try await connection.prepareStatement("SELECT version;", with: "test_query", logger: .psqlTest).get() - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Send one first and wait to send the other request until preparation is complete + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + #expect("test_database" == database) + } + #expect(rows == 1) + } - func testExecuteFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + let prepareRequest = try await channel.waitForPrepareRequest() + #expect(prepareRequest.parse.query == "SELECT datname FROM pg_stat_activity WHERE state = $1") + #expect(prepareRequest.parse.parameters.first == .text) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + #expect(name == String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) - try await connection.closeGracefully() + let preparedRequest1 = try await channel.waitForPreparedRequest() + var buffer = preparedRequest1.bind.parameters[0]! + let parameter1 = buffer.readString(length: buffer.readableBytes)! + #expect(parameter1 == "active") + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) - XCTAssertEqual(channel.isActive, false) + // Now that the statement has been prepared and executed, send another request that will only get executed + // without preparation + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + #expect("test_database" == database) + } + #expect(rows == 1) + } - do { - let statement = PSQLExecuteStatement(name: "SELECT version;", binds: .init(), rowDescription: nil) - _ = try await connection.execute(statement, logger: .psqlTest).get() - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) + let preparedRequest2 = try await channel.waitForPreparedRequest() + buffer = preparedRequest2.bind.parameters[0]! + let parameter2 = buffer.readString(length: buffer.readableBytes)! + #expect(parameter2 == "idle") + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) + // Ensure we received and responded to both the requests + let parameters = [parameter1, parameter2] + #expect(parameters.contains("active")) + #expect(parameters.contains("idle")) + } } } - func testExecutePreparedStatementFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) + @Test func testStatementPreparationOnlyHappensOnceWithConcurrentRequests() async throws { + try await self.withAsyncTestingChannel { connection, channel in - struct TestPreparedStatement: PostgresPreparedStatement { - static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" - typealias Row = (Int, String) - - var state: String + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Let them race to tests that requests and responses aren't mixed up + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + #expect("test_database_active" == database) + } + #expect(rows == 1) + } + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + #expect("test_database_idle" == database) + } + #expect(rows == 1) + } - func makeBindings() -> PostgresBindings { - var bindings = PostgresBindings() - bindings.append(self.state) - return bindings - } + // The channel deduplicates prepare requests, we're going to see only one of them + let prepareRequest = try await channel.waitForPrepareRequest() + #expect(prepareRequest.parse.query == "SELECT datname FROM pg_stat_activity WHERE state = $1") + #expect(prepareRequest.parse.parameters.first == .text) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + #expect(name == String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) - func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { - try row.decode(Row.self) + // Now both the tasks have their statements prepared. + // We should see both of their execute requests coming in, the order is nondeterministic + let preparedRequest1 = try await channel.waitForPreparedRequest() + var buffer = preparedRequest1.bind.parameters[0]! + let parameter1 = buffer.readString(length: buffer.readableBytes)! + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database_\(parameter1)"] + ], + commandTag: TestPrepareStatement.sql + ) + let preparedRequest2 = try await channel.waitForPreparedRequest() + buffer = preparedRequest2.bind.parameters[0]! + let parameter2 = buffer.readString(length: buffer.readableBytes)! + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database_\(parameter2)"] + ], + commandTag: TestPrepareStatement.sql + ) + // Ensure we received and responded to both the requests + let parameters = [parameter1, parameter2] + #expect(parameters.contains("active")) + #expect(parameters.contains("idle")) } } - - do { - let preparedStatement = TestPreparedStatement(state: "active") - _ = try await connection.execute(preparedStatement, logger: .psqlTest) - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } } - func testExecutePreparedStatementWithVoidRowFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + @Test func testStatementPreparationFailure() async throws { + try await self.withAsyncTestingChannel { connection, channel in - try await connection.closeGracefully() + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Send one first and wait to send the other request until preparation is complete + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + do { + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + Issue.record("Was supposed to fail") + } catch { + #expect(error is PSQLError) + } + } - XCTAssertEqual(channel.isActive, false) + let prepareRequest = try await channel.waitForPrepareRequest() + #expect(prepareRequest.parse.query == "SELECT datname FROM pg_stat_activity WHERE state = $1") + #expect(prepareRequest.parse.parameters.first == .text) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + #expect(name == String(reflecting: TestPrepareStatement.self)) - struct TestPreparedStatement: PostgresPreparedStatement { - static let sql = "SELECT * FROM pg_stat_activity WHERE state = $1" - typealias Row = () + // Respond with an error taking care to return a SQLSTATE that isn't + // going to get the connection closed. + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .sqlState : "26000" // invalid_sql_statement_name + ]))) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await channel.testingEventLoop.executeInContext { channel.read() } - var state: String - func makeBindings() -> PostgresBindings { - var bindings = PostgresBindings() - bindings.append(self.state) - return bindings - } - - func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { - () + // Send another requests with the same prepared statement, which should fail straight + // away without any interaction with the server + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + do { + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + Issue.record("Was supposed to fail") + } catch { + #expect(error is PSQLError) + } + } } } - - do { - let preparedStatement = TestPreparedStatement(state: "active") - _ = try await connection.execute(preparedStatement, logger: .psqlTest) - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } } - func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { + func withAsyncTestingChannel(_ body: (PostgresConnection, NIOAsyncTestingChannel) async throws -> ()) async throws { let eventLoop = NIOAsyncTestingEventLoop() - let channel = await NIOAsyncTestingChannel(handlers: [ - ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), - ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), - ], loop: eventLoop) + let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in + try channel.pipeline.syncOperations.addHandlers(ReverseByteToMessageHandler(PSQLFrontendMessageDecoder())) + try channel.pipeline.syncOperations.addHandlers(ReverseMessageToByteHandler(PSQLBackendMessageEncoder())) + } try await channel.connect(to: .makeAddressResolvingHost("localhost", port: 5432)) let configuration = PostgresConnection.Configuration( @@ -825,18 +671,20 @@ class PostgresConnectionTests: XCTestCase { let logger = self.logger async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: logger) let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) - XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false)))) + #expect(message == .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false)))) try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) let connection = try await connectionPromise - self.addTeardownBlock { - try await connection.close() + do { + try await body(connection, channel) + } catch { + } - return (connection, channel) + try await connection.close() } } diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index 816daf04..54f13e96 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -1,17 +1,17 @@ import Atomics import NIOEmbedded -import Dispatch -import XCTest +import NIOPosix +import Testing @testable import PostgresNIO import NIOCore import Logging -final class PostgresRowSequenceTests: XCTestCase { +@Suite struct PostgresRowSequenceTests { let logger = Logger(label: "PSQLRowStreamTests") - let eventLoop = EmbeddedEventLoop() - func testBackpressureWorks() async throws { + @Test func testBackpressureWorks() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -19,28 +19,29 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) let rowSequence = stream.asyncSequence() - XCTAssertEqual(dataSource.requestCount, 0) + #expect(dataSource.requestCount == 0) let dataRow: DataRow = [ByteBuffer(integer: Int64(1))] stream.receive([dataRow]) var iterator = rowSequence.makeAsyncIterator() let row = try await iterator.next() - XCTAssertEqual(dataSource.requestCount, 1) - XCTAssertEqual(row?.data, dataRow) + #expect(dataSource.requestCount == 1) + #expect(row?.data == dataRow) stream.receive(completion: .success("SELECT 1")) let empty = try await iterator.next() - XCTAssertNil(empty) + #expect(empty == nil) } - func testCancellationWorksWhileIterating() async throws { + @Test func testCancellationWorksWhileIterating() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -48,18 +49,18 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) let rowSequence = stream.asyncSequence() - XCTAssertEqual(dataSource.requestCount, 0) + #expect(dataSource.requestCount == 0) let dataRows: [DataRow] = (0..<128).map { [ByteBuffer(integer: Int64($0))] } stream.receive(dataRows) var counter = 0 for try await row in rowSequence { - XCTAssertEqual(try row.decode(Int.self), counter) + #expect(try row.decode(Int.self) == counter) counter += 1 if counter == 64 { @@ -67,11 +68,12 @@ final class PostgresRowSequenceTests: XCTestCase { } } - XCTAssertEqual(dataSource.cancelCount, 1) + #expect(dataSource.cancelCount == 1) } - func testCancellationWorksBeforeIterating() async throws { + @Test func testCancellationWorksBeforeIterating() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -79,24 +81,25 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) let rowSequence = stream.asyncSequence() - XCTAssertEqual(dataSource.requestCount, 0) + #expect(dataSource.requestCount == 0) let dataRows: [DataRow] = (0..<128).map { [ByteBuffer(integer: Int64($0))] } stream.receive(dataRows) var iterator: PostgresRowSequence.AsyncIterator? = rowSequence.makeAsyncIterator() iterator = nil - XCTAssertEqual(dataSource.cancelCount, 1) - XCTAssertNil(iterator, "Surpress warning") + #expect(dataSource.cancelCount == 1) + #expect(iterator == nil, "Surpress warning") } - func testDroppingTheSequenceCancelsTheSource() async throws { + @Test func testDroppingTheSequenceCancelsTheSource() throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -104,19 +107,20 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) var rowSequence: PostgresRowSequence? = stream.asyncSequence() rowSequence = nil - XCTAssertEqual(dataSource.cancelCount, 1) - XCTAssertNil(rowSequence, "Surpress warning") + #expect(dataSource.cancelCount == 1) + #expect(rowSequence == nil, "Surpress warning") } - func testStreamBasedOnCompletedQuery() async throws { + @Test func testStreamBasedOnCompletedQuery() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -124,7 +128,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -135,15 +139,16 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 0 for try await row in rowSequence { - XCTAssertEqual(try row.decode(Int.self), counter) + #expect(try row.decode(Int.self) == counter) counter += 1 } - XCTAssertEqual(dataSource.cancelCount, 0) + #expect(dataSource.cancelCount == 0) } - func testStreamIfInitializedWithAllData() async throws { + @Test func testStreamIfInitializedWithAllData() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -151,7 +156,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -163,15 +168,16 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 0 for try await row in rowSequence { - XCTAssertEqual(try row.decode(Int.self), counter) + #expect(try row.decode(Int.self) == counter) counter += 1 } - XCTAssertEqual(dataSource.cancelCount, 0) + #expect(dataSource.cancelCount == 0) } - func testStreamIfInitializedWithError() async throws { + @Test func testStreamIfInitializedWithError() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -179,7 +185,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -192,14 +198,15 @@ final class PostgresRowSequenceTests: XCTestCase { for try await _ in rowSequence { counter += 1 } - XCTFail("Expected that an error was thrown before.") + Issue.record("Expected that an error was thrown before.") } catch { - XCTAssertEqual(error as? PSQLError, .serverClosedConnection(underlying: nil)) + #expect(error as? PSQLError == .serverClosedConnection(underlying: nil)) } } - func testSucceedingRowContinuationsWorks() async throws { + @Test func testSucceedingRowContinuationsWorks() async throws { let dataSource = MockRowDataSource() + let eventLoop = NIOSingletons.posixEventLoopGroup.next() let stream = PSQLRowStream( source: .stream( [ @@ -207,31 +214,32 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: eventLoop, logger: self.logger ) - let rowSequence = stream.asyncSequence() + let rowSequence = try await eventLoop.submit { stream.asyncSequence() }.get() var rowIterator = rowSequence.makeAsyncIterator() - DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + eventLoop.scheduleTask(in: .seconds(1)) { let dataRows: [DataRow] = (0..<1).map { [ByteBuffer(integer: Int64($0))] } stream.receive(dataRows) } let row1 = try await rowIterator.next() - XCTAssertEqual(try row1?.decode(Int.self), 0) + #expect(try row1?.decode(Int.self) == 0) - DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + eventLoop.scheduleTask(in: .seconds(1)) { stream.receive(completion: .success("SELECT 1")) } let row2 = try await rowIterator.next() - XCTAssertNil(row2) + #expect(row2 == nil) } - func testFailingRowContinuationsWorks() async throws { + @Test func testFailingRowContinuationsWorks() async throws { let dataSource = MockRowDataSource() + let eventLoop = NIOSingletons.posixEventLoopGroup.next() let stream = PSQLRowStream( source: .stream( [ @@ -239,35 +247,36 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: eventLoop, logger: self.logger ) - let rowSequence = stream.asyncSequence() + let rowSequence = try await eventLoop.submit { stream.asyncSequence() }.get() var rowIterator = rowSequence.makeAsyncIterator() - DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + eventLoop.scheduleTask(in: .seconds(1)) { let dataRows: [DataRow] = (0..<1).map { [ByteBuffer(integer: Int64($0))] } stream.receive(dataRows) } let row1 = try await rowIterator.next() - XCTAssertEqual(try row1?.decode(Int.self), 0) + #expect(try row1?.decode(Int.self) == 0) - DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + eventLoop.scheduleTask(in: .seconds(1)) { stream.receive(completion: .failure(PSQLError.serverClosedConnection(underlying: nil))) } do { _ = try await rowIterator.next() - XCTFail("Expected that an error was thrown before.") + Issue.record("Expected that an error was thrown before.") } catch { - XCTAssertEqual(error as? PSQLError, .serverClosedConnection(underlying: nil)) + #expect(error as? PSQLError == .serverClosedConnection(underlying: nil)) } } - func testAdaptiveRowBufferShrinksAndGrows() async throws { + @Test func testAdaptiveRowBufferShrinksAndGrows() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -275,7 +284,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -285,20 +294,20 @@ final class PostgresRowSequenceTests: XCTestCase { let rowSequence = stream.asyncSequence() var rowIterator = rowSequence.makeAsyncIterator() - XCTAssertEqual(dataSource.requestCount, 0) + #expect(dataSource.requestCount == 0) _ = try await rowIterator.next() // new buffer size will be target -> don't ask for more - XCTAssertEqual(dataSource.requestCount, 0) + #expect(dataSource.requestCount == 0) _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more - XCTAssertEqual(dataSource.requestCount, 1) + #expect(dataSource.requestCount == 1) // if the buffer gets new rows so that it has equal or more than target (the target size // should be halved), however shrinking is only allowed AFTER the first extra rows were // received. let addDataRows1: [DataRow] = [[ByteBuffer(integer: Int64(0))]] stream.receive(addDataRows1) - XCTAssertEqual(dataSource.requestCount, 1) + #expect(dataSource.requestCount == 1) _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more - XCTAssertEqual(dataSource.requestCount, 2) + #expect(dataSource.requestCount == 2) // if the buffer gets new rows so that it has equal or more than target (the target size // should be halved) @@ -307,31 +316,32 @@ final class PostgresRowSequenceTests: XCTestCase { _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more for _ in 0..<(AdaptiveRowBuffer.defaultBufferTarget / 2) { _ = try await rowIterator.next() // Remove all rows until we are back at target - XCTAssertEqual(dataSource.requestCount, 2) + #expect(dataSource.requestCount == 2) } // if we remove another row we should trigger getting new rows. _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more - XCTAssertEqual(dataSource.requestCount, 3) + #expect(dataSource.requestCount == 3) // remove all remaining rows... this will trigger a target size double for _ in 0..<(AdaptiveRowBuffer.defaultBufferTarget/2 - 1) { _ = try await rowIterator.next() // Remove all rows until we are back at target - XCTAssertEqual(dataSource.requestCount, 3) + #expect(dataSource.requestCount == 3) } let fillBufferDataRows: [DataRow] = (0.. don't ask for more - XCTAssertEqual(dataSource.requestCount, 3) + #expect(dataSource.requestCount == 3) _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more - XCTAssertEqual(dataSource.requestCount, 4) + #expect(dataSource.requestCount == 4) } - func testAdaptiveRowShrinksToMin() async throws { + @Test func testAdaptiveRowShrinksToMin() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -339,7 +349,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -352,9 +362,9 @@ final class PostgresRowSequenceTests: XCTestCase { var rowIterator = rowSequence.makeAsyncIterator() // shrinking the buffer is only allowed after the first extra rows were received - XCTAssertEqual(dataSource.requestCount, 0) + #expect(dataSource.requestCount == 0) _ = try await rowIterator.next() - XCTAssertEqual(dataSource.requestCount, 1) + #expect(dataSource.requestCount == 1) stream.receive([[ByteBuffer(integer: Int64(1))]]) @@ -363,10 +373,10 @@ final class PostgresRowSequenceTests: XCTestCase { while currentTarget > AdaptiveRowBuffer.defaultBufferMinimum { // the buffer is filled up to currentTarget at that point, if we remove one row and add // one row it should shrink - XCTAssertEqual(dataSource.requestCount, expectedRequestCount) + #expect(dataSource.requestCount == expectedRequestCount) _ = try await rowIterator.next() expectedRequestCount += 1 - XCTAssertEqual(dataSource.requestCount, expectedRequestCount) + #expect(dataSource.requestCount == expectedRequestCount) stream.receive([[ByteBuffer(integer: Int64(1))], [ByteBuffer(integer: Int64(1))]]) let newTarget = currentTarget / 2 @@ -375,17 +385,18 @@ final class PostgresRowSequenceTests: XCTestCase { // consume all messages that are to much. for _ in 0..