From d97f26858003d87a1d2b52c0061b243035620996 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sun, 28 Feb 2021 13:23:47 +0100 Subject: [PATCH 001/292] Internal: Rename Command Contexts (#147) --- .../ConnectionStateMachine.swift | 10 ++++----- .../ExtendedQueryStateMachine.swift | 22 +++++++++---------- .../PrepareStatementStateMachine.swift | 14 ++++++------ .../PostgresNIO/New/PSQLChannelHandler.swift | 4 ++-- Sources/PostgresNIO/New/PSQLConnection.swift | 6 ++--- Sources/PostgresNIO/New/PSQLRows.swift | 2 +- Sources/PostgresNIO/New/PSQLTask.swift | 8 +++---- .../ConnectionStateMachineTests.swift | 2 +- .../ExtendedQueryStateMachineTests.swift | 6 ++--- .../PrepareStatementStateMachineTests.swift | 4 ++-- 10 files changed, 39 insertions(+), 39 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index e038f5ad..958cafd4 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -85,9 +85,9 @@ struct ConnectionStateMachine { // --- general actions case sendParseDescribeBindExecuteSync(query: String, binds: [PSQLEncodable]) case sendBindExecuteSync(statementName: String, binds: [PSQLEncodable]) - case failQuery(ExecuteExtendedQueryContext, with: PSQLError, cleanupContext: CleanUpContext?) - case succeedQuery(ExecuteExtendedQueryContext, columns: [PSQLBackendMessage.RowDescription.Column]) - case succeedQueryNoRowsComming(ExecuteExtendedQueryContext, commandTag: String) + case failQuery(ExtendedQueryContext, with: PSQLError, cleanupContext: CleanUpContext?) + case succeedQuery(ExtendedQueryContext, columns: [PSQLBackendMessage.RowDescription.Column]) + case succeedQueryNoRowsComming(ExtendedQueryContext, commandTag: String) // --- streaming actions // actions if query has requested next row but we are waiting for backend @@ -100,8 +100,8 @@ struct ConnectionStateMachine { // Prepare statement actions case sendParseDescribeSync(name: String, query: String) - case succeedPreparedStatementCreation(CreatePreparedStatementContext, with: PSQLBackendMessage.RowDescription?) - case failPreparedStatementCreation(CreatePreparedStatementContext, with: PSQLError, cleanupContext: CleanUpContext?) + case succeedPreparedStatementCreation(PrepareStatementContext, with: PSQLBackendMessage.RowDescription?) + case failPreparedStatementCreation(PrepareStatementContext, with: PSQLError, cleanupContext: CleanUpContext?) // Close actions case sendCloseSync(CloseTarget) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 0fa054d2..566597a7 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -2,17 +2,17 @@ struct ExtendedQueryStateMachine { enum State { - case initialized(ExecuteExtendedQueryContext) - case parseDescribeBindExecuteSyncSent(ExecuteExtendedQueryContext) + case initialized(ExtendedQueryContext) + case parseDescribeBindExecuteSyncSent(ExtendedQueryContext) - case parseCompleteReceived(ExecuteExtendedQueryContext) - case parameterDescriptionReceived(ExecuteExtendedQueryContext) - case rowDescriptionReceived(ExecuteExtendedQueryContext, [PSQLBackendMessage.RowDescription.Column]) - case noDataMessageReceived(ExecuteExtendedQueryContext) + case parseCompleteReceived(ExtendedQueryContext) + case parameterDescriptionReceived(ExtendedQueryContext) + case rowDescriptionReceived(ExtendedQueryContext, [PSQLBackendMessage.RowDescription.Column]) + case noDataMessageReceived(ExtendedQueryContext) /// A state that is used if a noData message was received before. If a row description was received `bufferingRows` is /// used after receiving a `bindComplete` message - case bindCompleteReceived(ExecuteExtendedQueryContext) + case bindCompleteReceived(ExtendedQueryContext) case bufferingRows([PSQLBackendMessage.RowDescription.Column], CircularBuffer<[PSQLData]>, readOnEmpty: Bool) case waitingForNextRow([PSQLBackendMessage.RowDescription.Column], CircularBuffer<[PSQLData]>, EventLoopPromise) @@ -27,9 +27,9 @@ struct ExtendedQueryStateMachine { case sendBindExecuteSync(statementName: String, binds: [PSQLEncodable]) // --- general actions - case failQuery(ExecuteExtendedQueryContext, with: PSQLError) - case succeedQuery(ExecuteExtendedQueryContext, columns: [PSQLBackendMessage.RowDescription.Column]) - case succeedQueryNoRowsComming(ExecuteExtendedQueryContext, commandTag: String) + case failQuery(ExtendedQueryContext, with: PSQLError) + case succeedQuery(ExtendedQueryContext, columns: [PSQLBackendMessage.RowDescription.Column]) + case succeedQueryNoRowsComming(ExtendedQueryContext, commandTag: String) // --- streaming actions // actions if query has requested next row but we are waiting for backend @@ -46,7 +46,7 @@ struct ExtendedQueryStateMachine { var state: State - init(queryContext: ExecuteExtendedQueryContext) { + init(queryContext: ExtendedQueryContext) { self.state = .initialized(queryContext) } diff --git a/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift index 2715b25a..98e18dbc 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift @@ -2,11 +2,11 @@ struct PrepareStatementStateMachine { enum State { - case initialized(CreatePreparedStatementContext) - case parseDescribeSent(CreatePreparedStatementContext) + case initialized(PrepareStatementContext) + case parseDescribeSent(PrepareStatementContext) - case parseCompleteReceived(CreatePreparedStatementContext) - case parameterDescriptionReceived(CreatePreparedStatementContext) + case parseCompleteReceived(PrepareStatementContext) + case parameterDescriptionReceived(PrepareStatementContext) case rowDescriptionReceived case noDataMessageReceived @@ -15,8 +15,8 @@ struct PrepareStatementStateMachine { enum Action { case sendParseDescribeSync(name: String, query: String) - case succeedPreparedStatementCreation(CreatePreparedStatementContext, with: PSQLBackendMessage.RowDescription?) - case failPreparedStatementCreation(CreatePreparedStatementContext, with: PSQLError) + case succeedPreparedStatementCreation(PrepareStatementContext, with: PSQLBackendMessage.RowDescription?) + case failPreparedStatementCreation(PrepareStatementContext, with: PSQLError) case read case wait @@ -24,7 +24,7 @@ struct PrepareStatementStateMachine { var state: State - init(createContext: CreatePreparedStatementContext) { + init(createContext: PrepareStatementContext) { self.state = .initialized(createContext) } diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PSQLChannelHandler.swift index f3c2e274..67b0dced 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PSQLChannelHandler.swift @@ -417,7 +417,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { } private func succeedQueryWithRowStream( - _ queryContext: ExecuteExtendedQueryContext, + _ queryContext: ExtendedQueryContext, columns: [PSQLBackendMessage.RowDescription.Column], context: ChannelHandlerContext) { @@ -448,7 +448,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { } private func succeedQueryWithoutRowStream( - _ queryContext: ExecuteExtendedQueryContext, + _ queryContext: ExtendedQueryContext, commandTag: String, context: ChannelHandlerContext) { diff --git a/Sources/PostgresNIO/New/PSQLConnection.swift b/Sources/PostgresNIO/New/PSQLConnection.swift index 391dd2f9..a49c9c15 100644 --- a/Sources/PostgresNIO/New/PSQLConnection.swift +++ b/Sources/PostgresNIO/New/PSQLConnection.swift @@ -131,7 +131,7 @@ final class PSQLConnection { return self.channel.eventLoop.makeFailedFuture(PSQLError.tooManyParameters) } let promise = self.channel.eventLoop.makePromise(of: PSQLRows.self) - let context = ExecuteExtendedQueryContext( + let context = ExtendedQueryContext( query: query, bind: bind, logger: logger, @@ -151,7 +151,7 @@ final class PSQLConnection { func prepareStatement(_ query: String, with name: String, logger: Logger) -> EventLoopFuture { let promise = self.channel.eventLoop.makePromise(of: PSQLBackendMessage.RowDescription?.self) - let context = CreatePreparedStatementContext( + let context = PrepareStatementContext( name: name, query: query, logger: logger, @@ -170,7 +170,7 @@ final class PSQLConnection { return self.channel.eventLoop.makeFailedFuture(PSQLError.tooManyParameters) } let promise = self.channel.eventLoop.makePromise(of: PSQLRows.self) - let context = ExecuteExtendedQueryContext( + let context = ExtendedQueryContext( preparedStatement: preparedStatement, bind: bind, logger: logger, diff --git a/Sources/PostgresNIO/New/PSQLRows.swift b/Sources/PostgresNIO/New/PSQLRows.swift index 62bae59e..6a25b90b 100644 --- a/Sources/PostgresNIO/New/PSQLRows.swift +++ b/Sources/PostgresNIO/New/PSQLRows.swift @@ -25,7 +25,7 @@ final class PSQLRows { private let jsonDecoder: PSQLJSONDecoder init(rowDescription: [PSQLBackendMessage.RowDescription.Column], - queryContext: ExecuteExtendedQueryContext, + queryContext: ExtendedQueryContext, eventLoop: EventLoop, cancel: @escaping () -> (), next: @escaping () -> EventLoopFuture) diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index edd97bdb..07ea10ca 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -1,6 +1,6 @@ enum PSQLTask { - case extendedQuery(ExecuteExtendedQueryContext) - case preparedStatement(CreatePreparedStatementContext) + case extendedQuery(ExtendedQueryContext) + case preparedStatement(PrepareStatementContext) case closeCommand(CloseCommandContext) func failWithError(_ error: PSQLError) { @@ -15,7 +15,7 @@ enum PSQLTask { } } -final class ExecuteExtendedQueryContext { +final class ExtendedQueryContext { enum Query { case unnamed(String) case preparedStatement(name: String, rowDescription: PSQLBackendMessage.RowDescription?) @@ -58,7 +58,7 @@ final class ExecuteExtendedQueryContext { } -final class CreatePreparedStatementContext { +final class PrepareStatementContext { let name: String let query: String let logger: Logger diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index 123a3957..c0b5c3a8 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -103,7 +103,7 @@ class ConnectionStateMachineTests: XCTestCase { let queryPromise = eventLoopGroup.next().makePromise(of: PSQLRows.self) var state = ConnectionStateMachine() - let extendedQueryContext = ExecuteExtendedQueryContext( + let extendedQueryContext = ExtendedQueryContext( query: "Select version()", bind: [], logger: .psqlTest, diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index 4f32541e..08005156 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -10,7 +10,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let promise = EmbeddedEventLoop().makePromise(of: PSQLRows.self) promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. let query = "DELETE FROM table WHERE id=$0" - let queryContext = ExecuteExtendedQueryContext(query: query, bind: [1], logger: logger, jsonDecoder: JSONDecoder(), promise: promise) + let queryContext = ExtendedQueryContext(query: query, bind: [1], logger: logger, jsonDecoder: JSONDecoder(), promise: promise) XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query: query, binds: [1])) XCTAssertEqual(state.parseCompleteReceived(), .wait) @@ -28,7 +28,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let queryPromise = EmbeddedEventLoop().makePromise(of: PSQLRows.self) queryPromise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. let query = "SELECT version()" - let queryContext = ExecuteExtendedQueryContext(query: query, bind: [], logger: logger, jsonDecoder: JSONDecoder(), promise: queryPromise) + let queryContext = ExtendedQueryContext(query: query, bind: [], logger: logger, jsonDecoder: JSONDecoder(), promise: queryPromise) XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query: query, binds: [])) XCTAssertEqual(state.parseCompleteReceived(), .wait) @@ -59,7 +59,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let promise = EmbeddedEventLoop().makePromise(of: PSQLRows.self) promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. let query = "DELETE FROM table WHERE id=$0" - let queryContext = ExecuteExtendedQueryContext(query: query, bind: [1], logger: logger, jsonDecoder: JSONDecoder(), promise: promise) + let queryContext = ExtendedQueryContext(query: query, bind: [1], logger: logger, jsonDecoder: JSONDecoder(), promise: promise) XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query: query, binds: [1])) XCTAssertEqual(state.parseCompleteReceived(), .wait) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift index 7b7862d0..50870b15 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift @@ -11,7 +11,7 @@ class PrepareStatementStateMachineTests: XCTestCase { let name = "haha" let query = #"SELECT id FROM users WHERE id = $1 "# - let prepareStatementContext = CreatePreparedStatementContext( + let prepareStatementContext = PrepareStatementContext( name: name, query: query, logger: .psqlTest, promise: promise) XCTAssertEqual(state.enqueue(task: .preparedStatement(prepareStatementContext)), @@ -36,7 +36,7 @@ class PrepareStatementStateMachineTests: XCTestCase { let name = "haha" let query = #"DELETE FROM users WHERE id = $1 "# - let prepareStatementContext = CreatePreparedStatementContext( + let prepareStatementContext = PrepareStatementContext( name: name, query: query, logger: .psqlTest, promise: promise) XCTAssertEqual(state.enqueue(task: .preparedStatement(prepareStatementContext)), From 33eea144f465a48b17f4c63573c40d33824fdde4 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 21 Apr 2021 15:44:51 +0200 Subject: [PATCH 002/292] Ignore errors when closing the connection (#151) --- .../ConnectionStateMachine.swift | 27 ++++++++++++++----- .../PostgresNIO/New/PSQLChannelHandler.swift | 4 +-- Sources/PostgresNIO/New/PSQLConnection.swift | 4 --- Sources/PostgresNIO/New/PSQLRows.swift | 4 ++- .../ConnectionStateMachineTests.swift | 13 +++++++++ .../ConnectionAction+TestUtils.swift | 2 ++ 6 files changed, 41 insertions(+), 13 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 958cafd4..1c3629b7 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -313,8 +313,7 @@ struct ConnectionStateMachine { .waitingToStartAuthentication, .authenticated, .readyForQuery, - .error, - .closing: + .error: return self.closeConnectionAndCleanup(.server(errorMessage)) case .authenticating(var authState): if authState.isComplete { @@ -352,6 +351,13 @@ struct ConnectionStateMachine { machine.state = .prepareStatement(preparedState, connectionContext) return machine.modify(with: action) } + case .closing: + // If the state machine is in state `.closing`, the connection shutdown was initiated + // by the client. This means a `TERMINATE` message has already been sent and the + // connection close was passed on to the channel. Therefore we await a channelInactive + // as the next event. + // Since a connection close was already issued, we should keep cool and just wait. + return .wait case .initialized, .closed: preconditionFailure("We should not receive server errors if we are not connected") case .modifying: @@ -367,8 +373,7 @@ struct ConnectionStateMachine { .sslHandlerAdded, .waitingToStartAuthentication, .authenticated, - .readyForQuery, - .closing: + .readyForQuery: return self.closeConnectionAndCleanup(error) case .authenticating(var authState): let action = authState.errorHappened(error) @@ -396,6 +401,16 @@ struct ConnectionStateMachine { } case .error: return .wait + case .closing: + // If the state machine is in state `.closing`, the connection shutdown was initiated + // by the client. This means a `TERMINATE` message has already been sent and the + // connection close was passed on to the channel. Therefore we await a channelInactive + // as the next event. + // For some reason Azure Postgres does not end ssl cleanly when terminating the + // connection. More documentation can be found in the issue: + // https://github.com/vapor/postgres-nio/issues/150 + // Since a connection close was already issued, we should keep cool and just wait. + return .wait case .closed: return self.closeConnectionAndCleanup(error) @@ -1108,8 +1123,8 @@ struct AuthContext: Equatable, CustomDebugStringConvertible { var debugDescription: String { """ - (username: \(String(reflecting: self.username)), \ - password: \(self.password != nil ? String(reflecting: self.password!) : "nil"), \ + AuthContext(username: \(String(reflecting: self.username)), \ + password: \(self.password != nil ? "********" : "nil"), \ database: \(self.database != nil ? String(reflecting: self.database!) : "nil")) """ } diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PSQLChannelHandler.swift index 67b0dced..aa2128ec 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PSQLChannelHandler.swift @@ -75,7 +75,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { } func errorCaught(context: ChannelHandlerContext, error: Error) { - self.logger.error("Channel error caught.", metadata: [.error: "\(error)"]) + self.logger.debug("Channel error caught.", metadata: [.error: "\(error)"]) let action = self.state.errorHappened(.channel(underlying: error)) self.run(action, with: context) } @@ -470,7 +470,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { _ cleanup: ConnectionStateMachine.ConnectionAction.CleanUpContext, context: ChannelHandlerContext) { - self.logger.error("Channel error caught. Closing connection.", metadata: [.error: "\(cleanup.error)"]) + self.logger.debug("Cleaning up and closing connection.", metadata: [.error: "\(cleanup.error)"]) // 1. fail all tasks cleanup.tasks.forEach { task in diff --git a/Sources/PostgresNIO/New/PSQLConnection.swift b/Sources/PostgresNIO/New/PSQLConnection.swift index a49c9c15..0bbea948 100644 --- a/Sources/PostgresNIO/New/PSQLConnection.swift +++ b/Sources/PostgresNIO/New/PSQLConnection.swift @@ -140,10 +140,6 @@ final class PSQLConnection { self.channel.write(PSQLTask.extendedQuery(context), promise: nil) - // success is logged in PSQLQuery - promise.futureResult.whenFailure { error in - logger.error("Query failed", metadata: [.error: "\(error)"]) - } return promise.futureResult } diff --git a/Sources/PostgresNIO/New/PSQLRows.swift b/Sources/PostgresNIO/New/PSQLRows.swift index 6a25b90b..23efe393 100644 --- a/Sources/PostgresNIO/New/PSQLRows.swift +++ b/Sources/PostgresNIO/New/PSQLRows.swift @@ -99,7 +99,9 @@ final class PSQLRows { } internal func noticeReceived(_ notice: PSQLBackendMessage.NoticeResponse) { - self.logger.notice("Notice Received \(notice)") + self.logger.debug("Notice Received", metadata: [ + .notice: "\(notice)" + ]) } internal func finalForward(_ finalForward: Result<(CircularBuffer<[PSQLData]>, commandTag: String), PSQLError>?) { diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index c0b5c3a8..8569d7c3 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -91,6 +91,19 @@ class ConnectionStateMachineTests: XCTestCase { .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unexpectedBackendMessage(.readyForQuery(.idle)), closePromise: nil))) } + func testErrorIsIgnoredWhenClosingConnection() { + // test ignore unclean shutdown when closing connection + var stateIgnoreChannelError = ConnectionStateMachine(.closing) + + XCTAssertEqual(stateIgnoreChannelError.errorHappened(.channel(underlying: NIOSSLError.uncleanShutdown)), .wait) + XCTAssertEqual(stateIgnoreChannelError.closed(), .fireChannelInactive) + + // test ignore any other error when closing connection + + var stateIgnoreErrorMessage = ConnectionStateMachine(.closing) + XCTAssertEqual(stateIgnoreErrorMessage.errorReceived(.init(fields: [:])), .wait) + XCTAssertEqual(stateIgnoreErrorMessage.closed(), .fireChannelInactive) + } func testFailQueuedQueriesOnAuthenticationFailure() throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift index c4f1af1f..d99c4280 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -74,6 +74,8 @@ extension ConnectionStateMachine.ConnectionAction: Equatable { return lhsName == rhsName && lhsQuery == rhsQuery case (.succeedPreparedStatementCreation(let lhsContext, let lhsRowDescription), .succeedPreparedStatementCreation(let rhsContext, let rhsRowDescription)): return lhsContext === rhsContext && lhsRowDescription == rhsRowDescription + case (.fireChannelInactive, .fireChannelInactive): + return true default: return false } From a095fc7ea77a96e8d4577b80af412a24aa3299b9 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 6 May 2021 17:09:37 +0200 Subject: [PATCH 003/292] Add platform requirements for iOS, watchOS and tvOS (#154) To use postgres-nio from iOS, watchOS and tvOS the platform requirements need to be set correctly. --- Package.swift | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Package.swift b/Package.swift index 0f017029..f8b5e3ea 100644 --- a/Package.swift +++ b/Package.swift @@ -4,7 +4,10 @@ import PackageDescription let package = Package( name: "postgres-nio", platforms: [ - .macOS(.v10_15) + .macOS(.v10_15), + .iOS(.v13), + .watchOS(.v6), + .tvOS(.v13), ], products: [ .library(name: "PostgresNIO", targets: ["PostgresNIO"]), From 8c611cec5a82a4f8eb493094507c446c982e0740 Mon Sep 17 00:00:00 2001 From: Jaap Wijnen Date: Fri, 7 May 2021 15:33:56 +0200 Subject: [PATCH 004/292] Update api-docs.yml (#156) --- .github/workflows/api-docs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/api-docs.yml b/.github/workflows/api-docs.yml index f27346af..d521498e 100644 --- a/.github/workflows/api-docs.yml +++ b/.github/workflows/api-docs.yml @@ -2,7 +2,7 @@ name: deploy-api-docs on: push: branches: - - master + - main jobs: deploy: From 290f3c13db8acdfaa832d118370ba7bf3195c9bf Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 10 May 2021 18:54:39 +0200 Subject: [PATCH 005/292] Split out IntegrationTests into separate testTarget (#157) This splits out all the integration tests into a separate test target. This makes a number of things easier in day to day development: - Unit tests can be run without setting up a test database. - We can observe the test coverage in our unit tests only. - This allows us to decouple testing the language vs. testing the functionality against a database. --- Package.swift | 4 + .../PSQLIntegrationTests.swift} | 0 .../PerformanceTests.swift | 0 .../PostgresNIOTests.swift | 8 +- Tests/IntegrationTests/Utilities.swift | 75 ++++++++++++++ .../New/Extensions/LoggingUtils.swift | 7 -- Tests/PostgresNIOTests/Utilities.swift | 98 +------------------ 7 files changed, 87 insertions(+), 105 deletions(-) rename Tests/{PostgresNIOTests/New/IntegrationTests.swift => IntegrationTests/PSQLIntegrationTests.swift} (100%) rename Tests/{PostgresNIOTests => IntegrationTests}/PerformanceTests.swift (100%) rename Tests/{PostgresNIOTests => IntegrationTests}/PostgresNIOTests.swift (99%) create mode 100644 Tests/IntegrationTests/Utilities.swift diff --git a/Package.swift b/Package.swift index f8b5e3ea..b49f658e 100644 --- a/Package.swift +++ b/Package.swift @@ -33,5 +33,9 @@ let package = Package( .target(name: "PostgresNIO"), .product(name: "NIOTestUtils", package: "swift-nio"), ]), + .testTarget(name: "IntegrationTests", dependencies: [ + .target(name: "PostgresNIO"), + .product(name: "NIOTestUtils", package: "swift-nio"), + ]), ] ) diff --git a/Tests/PostgresNIOTests/New/IntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift similarity index 100% rename from Tests/PostgresNIOTests/New/IntegrationTests.swift rename to Tests/IntegrationTests/PSQLIntegrationTests.swift diff --git a/Tests/PostgresNIOTests/PerformanceTests.swift b/Tests/IntegrationTests/PerformanceTests.swift similarity index 100% rename from Tests/PostgresNIOTests/PerformanceTests.swift rename to Tests/IntegrationTests/PerformanceTests.swift diff --git a/Tests/PostgresNIOTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift similarity index 99% rename from Tests/PostgresNIOTests/PostgresNIOTests.swift rename to Tests/IntegrationTests/PostgresNIOTests.swift index e8162e9f..d4095658 100644 --- a/Tests/PostgresNIOTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -1050,7 +1050,9 @@ final class PostgresNIOTests: XCTestCase { defer { XCTAssertNoThrow( try conn?.close().wait() ) } let binds = [PostgresData].init(repeating: .null, count: Int(Int16.max) + 1) XCTAssertThrowsError(try conn?.query("SELECT version()", binds).wait()) { error in - XCTAssertEqual(error as? PSQLError, .tooManyParameters) + guard case .tooManyParameters = (error as? PSQLError)?.base else { + return XCTFail("Unexpected error: \(error)") + } } } @@ -1189,10 +1191,6 @@ final class PostgresNIOTests: XCTestCase { } } -func env(_ name: String) -> String? { - getenv(name).flatMap { String(cString: $0) } -} - let isLoggingConfigured: Bool = { LoggingSystem.bootstrap { label in var handler = StreamLogHandler.standardOutput(label: label) diff --git a/Tests/IntegrationTests/Utilities.swift b/Tests/IntegrationTests/Utilities.swift new file mode 100644 index 00000000..3c762219 --- /dev/null +++ b/Tests/IntegrationTests/Utilities.swift @@ -0,0 +1,75 @@ +import PostgresNIO +import XCTest +import Logging +#if canImport(Darwin) +import Darwin.C +#else +import Glibc +#endif + +extension PostgresConnection { + static func address() throws -> SocketAddress { + try .makeAddressResolvingHost(env("POSTGRES_HOSTNAME") ?? "localhost", port: 5432) + } + + static func testUnauthenticated(on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { + var logger = Logger(label: "postgres.connection.test") + logger.logLevel = logLevel + do { + return connect(to: try address(), logger: logger, on: eventLoop) + } catch { + return eventLoop.makeFailedFuture(error) + } + } + + static func test(on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { + return testUnauthenticated(on: eventLoop, logLevel: logLevel).flatMap { conn in + return conn.authenticate( + username: env("POSTGRES_USER") ?? "vapor_username", + database: env("POSTGRES_DB") ?? "vapor_database", + password: env("POSTGRES_PASSWORD") ?? "vapor_password" + ).map { + return conn + }.flatMapError { error in + conn.close().flatMapThrowing { + throw error + } + } + } + } +} + +extension Logger { + static var psqlTest: Logger { + var logger = Logger(label: "psql.test") + logger.logLevel = .info + return logger + } +} + +func env(_ name: String) -> String? { + getenv(name).flatMap { String(cString: $0) } +} + +extension XCTestCase { + + public static var shouldRunLongRunningTests: Bool { + // The env var must be set and have the value `"true"`, `"1"`, or `"yes"` (case-insensitive). + // For the sake of sheer annoying pedantry, values like `"2"` are treated as false. + guard let rawValue = env("POSTGRES_LONG_RUNNING_TESTS") else { return false } + if let boolValue = Bool(rawValue) { return boolValue } + if let intValue = Int(rawValue) { return intValue == 1 } + return rawValue.lowercased() == "yes" + } + + public static var shouldRunPerformanceTests: Bool { + // Same semantics as above. Any present non-truthy value will explicitly disable performance + // tests even if they would've overwise run in the current configuration. + let defaultValue = !_isDebugAssertConfiguration() // default to not running in debug builds + + guard let rawValue = env("POSTGRES_PERFORMANCE_TESTS") else { return defaultValue } + if let boolValue = Bool(rawValue) { return boolValue } + if let intValue = Int(rawValue) { return intValue == 1 } + return rawValue.lowercased() == "yes" + } +} diff --git a/Tests/PostgresNIOTests/New/Extensions/LoggingUtils.swift b/Tests/PostgresNIOTests/New/Extensions/LoggingUtils.swift index 610d8f10..fdada802 100644 --- a/Tests/PostgresNIOTests/New/Extensions/LoggingUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/LoggingUtils.swift @@ -1,9 +1,2 @@ import Logging -extension Logger { - static var psqlTest: Logger { - var logger = Logger(label: "psql.test") - logger.logLevel = .info - return logger - } -} diff --git a/Tests/PostgresNIOTests/Utilities.swift b/Tests/PostgresNIOTests/Utilities.swift index 66e9949b..610d8f10 100644 --- a/Tests/PostgresNIOTests/Utilities.swift +++ b/Tests/PostgresNIOTests/Utilities.swift @@ -1,97 +1,9 @@ import Logging -import PostgresNIO -import XCTest -extension PostgresConnection { - static func address() throws -> SocketAddress { - try .makeAddressResolvingHost( env("POSTGRES_HOSTNAME") ?? "localhost", port: 5432) +extension Logger { + static var psqlTest: Logger { + var logger = Logger(label: "psql.test") + logger.logLevel = .info + return logger } - - static func testUnauthenticated(on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { - var logger = Logger(label: "postgres.connection.test") - logger.logLevel = logLevel - do { - return connect(to: try address(), logger: logger, on: eventLoop) - } catch { - return eventLoop.makeFailedFuture(error) - } - } - - static func test(on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { - return testUnauthenticated(on: eventLoop, logLevel: logLevel).flatMap { conn in - return conn.authenticate( - username: env("POSTGRES_USER") ?? "vapor_username", - database: env("POSTGRES_DB") ?? "vapor_database", - password: env("POSTGRES_PASSWORD") ?? "vapor_password" - ).map { - return conn - }.flatMapError { error in - conn.close().flatMapThrowing { - throw error - } - } - } - } -} - -extension XCTestCase { - - public static var shouldRunLongRunningTests: Bool { - // The env var must be set and have the value `"true"`, `"1"`, or `"yes"` (case-insensitive). - // For the sake of sheer annoying pedantry, values like `"2"` are treated as false. - guard let rawValue = ProcessInfo.processInfo.environment["POSTGRES_LONG_RUNNING_TESTS"] else { return false } - if let boolValue = Bool(rawValue) { return boolValue } - if let intValue = Int(rawValue) { return intValue == 1 } - return rawValue.lowercased() == "yes" - } - - public static var shouldRunPerformanceTests: Bool { - // Same semantics as above. Any present non-truthy value will explicitly disable performance - // tests even if they would've overwise run in the current configuration. - let defaultValue = !_isDebugAssertConfiguration() // default to not running in debug builds - - guard let rawValue = ProcessInfo.processInfo.environment["POSTGRES_PERFORMANCE_TESTS"] else { return defaultValue } - if let boolValue = Bool(rawValue) { return boolValue } - if let intValue = Int(rawValue) { return intValue == 1 } - return rawValue.lowercased() == "yes" - } - -} - - -// 1247.typisdefined: 0x01 (BOOLEAN) -// 1247.typbasetype: 0x00000000 (OID) -// 1247.typnotnull: 0x00 (BOOLEAN) -// 1247.typcategory: 0x42 (CHAR) -// 1247.typname: 0x626f6f6c (NAME) -// 1247.typbyval: 0x01 (BOOLEAN) -// 1247.typrelid: 0x00000000 (OID) -// 1247.typalign: 0x63 (CHAR) -// 1247.typndims: 0x00000000 (INTEGER) -// 1247.typacl: null -// 1247.typsend: 0x00000985 (REGPROC) -// 1247.typmodout: 0x00000000 (REGPROC) -// 1247.typstorage: 0x70 (CHAR) -// 1247.typispreferred: 0x01 (BOOLEAN) -// 1247.typinput: 0x000004da (REGPROC) -// 1247.typoutput: 0x000004db (REGPROC) -// 1247.typlen: 0x0001 (SMALLINT) -// 1247.typcollation: 0x00000000 (OID) -// 1247.typdefaultbin: null -// 1247.typelem: 0x00000000 (OID) -// 1247.typnamespace: 0x0000000b (OID) -// 1247.typtype: 0x62 (CHAR) -// 1247.typowner: 0x0000000a (OID) -// 1247.typdefault: null -// 1247.typtypmod: 0xffffffff (INTEGER) -// 1247.typarray: 0x000003e8 (OID) -// 1247.typreceive: 0x00000984 (REGPROC) -// 1247.typmodin: 0x00000000 (REGPROC) -// 1247.typanalyze: 0x00000000 (REGPROC) -// 1247.typdelim: 0x2c (CHAR) -struct PGType: Decodable { - var typname: String - var typnamespace: UInt32 - var typowner: UInt32 - var typlen: Int16 } From 6629d63420e0c00bd3af0fd4551d78e6cbaf7e33 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 11 May 2021 13:16:41 +0200 Subject: [PATCH 006/292] Use sync pipeline operations (#152) --- Package.swift | 6 ++-- .../PostgresNIO/New/PSQLChannelHandler.swift | 28 +++++++++---------- Sources/PostgresNIO/New/PSQLConnection.swift | 21 +++++++------- .../New/PSQLChannelHandlerTests.swift | 28 +++++++++++++------ 4 files changed, 45 insertions(+), 38 deletions(-) diff --git a/Package.swift b/Package.swift index b49f658e..86dfd2d1 100644 --- a/Package.swift +++ b/Package.swift @@ -13,11 +13,11 @@ let package = Package( .library(name: "PostgresNIO", targets: ["PostgresNIO"]), ], dependencies: [ + .package(url: "https://github.com/apple/swift-nio.git", from: "2.28.0"), + .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.12.0"), .package(url: "https://github.com/apple/swift-crypto.git", from: "1.0.0"), - .package(url: "https://github.com/apple/swift-nio.git", from: "2.0.0"), - .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.0.0"), - .package(url: "https://github.com/apple/swift-log.git", from: "1.0.0"), + .package(url: "https://github.com/apple/swift-log.git", from: "1.4.0"), ], targets: [ .target(name: "PostgresNIO", dependencies: [ diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PSQLChannelHandler.swift index aa2128ec..84819d24 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PSQLChannelHandler.swift @@ -20,18 +20,18 @@ final class PSQLChannelHandler: ChannelDuplexHandler { } private var currentQuery: PSQLRows? private let authentificationConfiguration: PSQLConnection.Configuration.Authentication? - private let enableSSLCallback: ((Channel) -> EventLoopFuture)? + private let configureSSLCallback: ((Channel) throws -> Void)? /// this delegate should only be accessed on the connections `EventLoop` weak var notificationDelegate: PSQLChannelHandlerNotificationDelegate? init(authentification: PSQLConnection.Configuration.Authentication?, logger: Logger, - enableSSLCallback: ((Channel) -> EventLoopFuture)? = nil) + configureSSLCallback: ((Channel) throws -> Void)?) { self.state = ConnectionStateMachine() self.authentificationConfiguration = authentification - self.enableSSLCallback = enableSSLCallback + self.configureSSLCallback = configureSSLCallback self.logger = logger } @@ -40,11 +40,11 @@ final class PSQLChannelHandler: ChannelDuplexHandler { init(authentification: PSQLConnection.Configuration.Authentication?, state: ConnectionStateMachine = .init(.initialized), logger: Logger = .psqlNoOpLogger, - enableSSLCallback: ((Channel) -> EventLoopFuture)? = nil) + configureSSLCallback: ((Channel) throws -> Void)?) { self.state = state self.authentificationConfiguration = authentification - self.enableSSLCallback = enableSSLCallback + self.configureSSLCallback = configureSSLCallback self.logger = logger } #endif @@ -302,7 +302,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { // MARK: - Private Methods - private func connected(context: ChannelHandlerContext) { - let action = self.state.connected(requireTLS: self.enableSSLCallback != nil) + let action = self.state.connected(requireTLS: self.configureSSLCallback != nil) self.run(action, with: context) } @@ -310,15 +310,13 @@ final class PSQLChannelHandler: ChannelDuplexHandler { private func establishSSLConnection(context: ChannelHandlerContext) { // This method must only be called, if we signalized the StateMachine before that we are // able to setup a SSL connection. - self.enableSSLCallback!(context.channel).whenComplete { result in - switch result { - case .success: - let action = self.state.sslHandlerAdded() - self.run(action, with: context) - case .failure(let error): - let action = self.state.errorHappened(.failedToAddSSLHandler(underlying: error)) - self.run(action, with: context) - } + do { + try self.configureSSLCallback!(context.channel) + let action = self.state.sslHandlerAdded() + self.run(action, with: context) + } catch { + let action = self.state.errorHappened(.failedToAddSSLHandler(underlying: error)) + self.run(action, with: context) } } diff --git a/Sources/PostgresNIO/New/PSQLConnection.swift b/Sources/PostgresNIO/New/PSQLConnection.swift index 0bbea948..5869334b 100644 --- a/Sources/PostgresNIO/New/PSQLConnection.swift +++ b/Sources/PostgresNIO/New/PSQLConnection.swift @@ -215,17 +215,16 @@ final class PSQLConnection { .channelInitializer { channel in let decoder = ByteToMessageHandler(PSQLBackendMessage.Decoder()) - var enableSSLCallback: ((Channel) -> EventLoopFuture)? = nil + var configureSSLCallback: ((Channel) throws -> ())? = nil if let tlsConfiguration = configuration.tlsConfiguration { - enableSSLCallback = { channel in - channel.eventLoop.submit { - let sslContext = try NIOSSLContext(configuration: tlsConfiguration) - return try NIOSSLClientHandler( - context: sslContext, - serverHostname: configuration.sslServerHostname) - }.flatMap { sslHandler in - channel.pipeline.addHandler(sslHandler, position: .before(decoder)) - } + configureSSLCallback = { channel in + channel.eventLoop.assertInEventLoop() + + let sslContext = try NIOSSLContext(configuration: tlsConfiguration) + let sslHandler = try NIOSSLClientHandler( + context: sslContext, + serverHostname: configuration.sslServerHostname) + try channel.pipeline.syncOperations.addHandler(sslHandler, position: .before(decoder)) } } @@ -235,7 +234,7 @@ final class PSQLConnection { PSQLChannelHandler( authentification: configuration.authentication, logger: logger, - enableSSLCallback: enableSSLCallback), + configureSSLCallback: configureSSLCallback), PSQLEventsHandler(logger: logger) ]) } diff --git a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift index a1f49d19..929337ba 100644 --- a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift @@ -1,5 +1,6 @@ import XCTest import NIO +import NIOTLS @testable import PostgresNIO class PSQLChannelHandlerTests: XCTestCase { @@ -8,7 +9,7 @@ class PSQLChannelHandlerTests: XCTestCase { func testHandlerAddedWithoutSSL() { let config = self.testConnectionConfiguration() - let handler = PSQLChannelHandler(authentification: config.authentication) + let handler = PSQLChannelHandler(authentification: config.authentication, configureSSLCallback: nil) let embedded = EmbeddedChannel(handler: handler) defer { XCTAssertNoThrow(try embedded.finish()) } @@ -35,7 +36,6 @@ class PSQLChannelHandlerTests: XCTestCase { var addSSLCallbackIsHit = false let handler = PSQLChannelHandler(authentification: config.authentication) { channel in addSSLCallbackIsHit = true - return channel.eventLoop.makeSucceededFuture(()) } let embedded = EmbeddedChannel(handler: handler) @@ -48,14 +48,24 @@ class PSQLChannelHandlerTests: XCTestCase { XCTAssertEqual(request.code, 80877103) - // first we need to add an encoder, because NIOSSLHandler can only - // operate on ByteBuffer - let future = embedded.pipeline.addHandlers(MessageToByteHandler(PSQLFrontendMessage.Encoder.forTests), position: .first) - XCTAssertNoThrow(try future.wait()) XCTAssertNoThrow(try embedded.writeInbound(PSQLBackendMessage.sslSupported)) // a NIOSSLHandler has been added, after it SSL had been negotiated XCTAssertTrue(addSSLCallbackIsHit) + + // signal that the ssl connection has been established + embedded.pipeline.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: "")) + + // startup message should be issued + var maybeStartupMessage: PSQLFrontendMessage? + XCTAssertNoThrow(maybeStartupMessage = try embedded.readOutbound(as: PSQLFrontendMessage.self)) + guard case .startup(let startupMessage) = maybeStartupMessage else { + return XCTFail("Unexpected message") + } + + XCTAssertEqual(startupMessage.parameters.user, config.authentication?.username) + XCTAssertEqual(startupMessage.parameters.database, config.authentication?.database) + XCTAssertEqual(startupMessage.parameters.replication, .false) } func testSSLUnsupportedClosesConnection() { @@ -64,7 +74,7 @@ class PSQLChannelHandlerTests: XCTestCase { let handler = PSQLChannelHandler(authentification: config.authentication) { channel in XCTFail("This callback should never be exectuded") - return channel.eventLoop.makeFailedFuture(PSQLError.sslUnsupported) + throw PSQLError.sslUnsupported } let embedded = EmbeddedChannel(handler: handler) let eventHandler = TestEventHandler() @@ -94,7 +104,7 @@ class PSQLChannelHandlerTests: XCTestCase { database: config.authentication?.database ) let state = ConnectionStateMachine(.waitingToStartAuthentication) - let handler = PSQLChannelHandler(authentification: config.authentication, state: state) + let handler = PSQLChannelHandler(authentification: config.authentication, state: state, configureSSLCallback: nil) let embedded = EmbeddedChannel(handler: handler) embedded.triggerUserOutboundEvent(PSQLOutgoingEvent.authenticate(authContext), promise: nil) @@ -119,7 +129,7 @@ class PSQLChannelHandlerTests: XCTestCase { database: config.authentication?.database ) let state = ConnectionStateMachine(.waitingToStartAuthentication) - let handler = PSQLChannelHandler(authentification: config.authentication, state: state) + let handler = PSQLChannelHandler(authentification: config.authentication, state: state, configureSSLCallback: nil) let embedded = EmbeddedChannel(handler: handler) embedded.triggerUserOutboundEvent(PSQLOutgoingEvent.authenticate(authContext), promise: nil) From 1140dd9fefc9050f83c60a85dad2dbb06c041e90 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 11 May 2021 13:20:04 +0200 Subject: [PATCH 007/292] Add one Auth and one SSL test case (#159) Adding two test cases as a by product of some other work... - Add a test case for instances in which adding a `SSLHandler` fails - Add a test case for instances in which a md5 password requested but none is provided. --- .../AuthenticationStateMachineTests.swift | 10 ++++++++++ .../ConnectionStateMachineTests.swift | 10 ++++++++++ .../PostgresNIOTests/New/PSQLBackendMessageTests.swift | 2 ++ 3 files changed, 22 insertions(+) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift index c590a934..a1cfbb5c 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift @@ -22,6 +22,16 @@ class AuthenticationStateMachineTests: XCTestCase { XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait) } + func testAuthenticateMD5WithoutPassword() { + let authContext = AuthContext(username: "test", password: nil, database: "test") + var state = ConnectionStateMachine(.waitingToStartAuthentication) + let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) + + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) + XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), + .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .authMechanismRequiresPassword, closePromise: nil))) + } + func testAuthenticateOkAfterStartUpWithoutAuthChallenge() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine(.waitingToStartAuthentication) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index 8569d7c3..b2ee2652 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -25,6 +25,16 @@ class ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) } + func testSSLStartupFailHandler() { + struct SSLHandlerAddError: Error, Equatable {} + + var state = ConnectionStateMachine() + XCTAssertEqual(state.connected(requireTLS: true), .sendSSLRequest) + XCTAssertEqual(state.sslSupportedReceived(), .establishSSLConnection) + let failError: PSQLError = .failedToAddSSLHandler(underlying: SSLHandlerAddError()) + XCTAssertEqual(state.errorHappened(failError), .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: failError, closePromise: nil))) + } + func testSSLStartupSSLUnsupported() { var state = ConnectionStateMachine() diff --git a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift index 6968fd32..717fa455 100644 --- a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift @@ -269,6 +269,8 @@ class PSQLBackendMessageTests: XCTestCase { XCTAssertEqual("\(PSQLBackendMessage.authentication(.sspi))", ".authentication(.sspi)") + XCTAssertEqual("\(PSQLBackendMessage.parameterStatus(.init(parameter: "foo", value: "bar")))", + #".parameterStatus(parameter: "foo", value: "bar")"#) XCTAssertEqual("\(PSQLBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 4567)))", ".backendKeyData(processID: 1234, secretKey: 4567)") From 245d12246a9d72c88cc1d8e9cf6d2cb4e0b9fa7a Mon Sep 17 00:00:00 2001 From: Tim Condon <0xTim@users.noreply.github.com> Date: Wed, 9 Jun 2021 17:34:50 +0100 Subject: [PATCH 008/292] Fix API docs link. Resolves #163 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a9f16691..a60c9802 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ This package has no additional system dependencies. ## API Docs -Check out the [PostgresNIO API docs](https://api.vapor.codes/postgres-nio/master/PostgresNIO/) for a detailed look at all of the classes, structs, protocols, and more. +Check out the [PostgresNIO API docs](https://api.vapor.codes/postgres-nio/main/PostgresNIO/) for a detailed look at all of the classes, structs, protocols, and more. ## Getting Started From 8527fefa46a840a12308bc5c9f50885dce3146c2 Mon Sep 17 00:00:00 2001 From: Tim Condon <0xTim@users.noreply.github.com> Date: Wed, 16 Jun 2021 11:33:43 +0100 Subject: [PATCH 009/292] Add link to our SECURITY.md (#164) --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index a60c9802..99530a85 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,10 @@ PostgresNIO supports the following platforms: - Ubuntu 16.04+ - macOS 10.15+ +### Secrurity + +Please see [SECURITY.md](https://github.com/vapor/.github/blob/main/SECURITY.md) for details on the security process. + ## Overview PostgresNIO is a client package for connecting to, authorizing, and querying a PostgreSQL server. At the heart of this module are NIO channel handlers for parsing and serializing messages in PostgreSQL's proprietary wire protocol. These channel handlers are combined in a request / response style connection type that provides a convenient, client-like interface for performing queries. From e18b7f899e4e48e9e93497f53d886c470535218c Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 22 Jul 2021 11:27:58 +0200 Subject: [PATCH 010/292] Use PSQLFormat when encoding and decoding (#158) Postgres supports two encodings on the wire text (0) and binary (1). Up until now the new `PSQLDecodable` only supported binary encoding. - This PR does not change any public API - Rename `PSQLFormatCode` to `PSQLFormat` - Add `PSQLFormat` as parameter to the decode function. - Add `psqlFormat: PSQLFormat` as protocol requirement for `PSQLEncodable` - All existing types encode into and decode from `.binary` format only (we can change this with follow up PRs) Co-authored-by: George Barnett --- .../PostgresConnection+Database.swift | 2 +- .../PostgresDatabase+PreparedQuery.swift | 2 +- .../ExtendedQueryStateMachine.swift | 19 ++++- .../New/Data/Array+PSQLCodable.swift | 13 ++- .../New/Data/Bool+PSQLCodable.swift | 42 ++++++++-- .../New/Data/Bytes+PSQLCodable.swift | 20 ++++- .../New/Data/Date+PSQLCodable.swift | 6 +- .../New/Data/Float+PSQLCodable.swift | 34 ++++++-- .../New/Data/Int+PSQLCodable.swift | 84 +++++++++++++------ .../New/Data/JSON+PSQLCodable.swift | 13 +-- .../New/Data/Optional+PSQLCodable.swift | 11 ++- .../Data/RawRepresentable+PSQLCodable.swift | 8 +- .../New/Data/String+PSQLCodable.swift | 16 ++-- .../New/Data/UUID+PSQLCodable.swift | 22 +++-- Sources/PostgresNIO/New/Messages/Bind.swift | 8 +- .../New/Messages/RowDescription.swift | 12 +-- Sources/PostgresNIO/New/PSQLCodable.swift | 17 +++- Sources/PostgresNIO/New/PSQLData.swift | 19 +++-- Sources/PostgresNIO/Postgres+PSQLCompat.swift | 8 +- .../PSQLIntegrationTests.swift | 22 +++++ .../ExtendedQueryStateMachineTests.swift | 17 ++-- .../PrepareStatementStateMachineTests.swift | 2 +- .../New/Data/Array+PSQLCodableTests.swift | 18 ++-- .../New/Data/Bool+PSQLCodableTests.swift | 58 +++++++++++-- .../New/Data/Bytes+PSQLCodableTests.swift | 4 +- .../New/Data/Date+PSQLCodableTests.swift | 18 ++-- .../New/Data/Float+PSQLCodableTests.swift | 18 ++-- .../New/Data/JSON+PSQLCodableTests.swift | 23 ++++- .../New/Data/Optional+PSQLCodableTests.swift | 8 +- .../RawRepresentable+PSQLCodableTests.swift | 6 +- .../New/Data/String+PSQLCodableTests.swift | 10 +-- .../New/Data/UUID+PSQLCodableTests.swift | 26 +++--- .../New/Messages/BindTests.swift | 9 +- .../New/Messages/RowDescriptionTests.swift | 20 ++--- .../PostgresNIOTests/New/PSQLDataTests.swift | 2 +- 35 files changed, 438 insertions(+), 179 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift index 30c5009d..f6f99a1c 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift @@ -23,7 +23,7 @@ extension PostgresConnection: PostgresDatabase { dataType: PostgresDataType(UInt32(column.dataType.rawValue)), dataTypeSize: column.dataTypeSize, dataTypeModifier: column.dataTypeModifier, - formatCode: .init(psqlFormatCode: column.formatCode) + formatCode: .init(psqlFormatCode: column.format) ) } diff --git a/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift b/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift index 327bef98..77f8be45 100644 --- a/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift +++ b/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift @@ -43,7 +43,7 @@ public struct PreparedQuery { dataType: PostgresDataType(UInt32(column.dataType.rawValue)), dataTypeSize: column.dataTypeSize, dataTypeModifier: column.dataTypeModifier, - formatCode: .init(psqlFormatCode: column.formatCode) + formatCode: .init(psqlFormatCode: column.format) ) } diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 566597a7..faf15626 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -114,7 +114,20 @@ struct ExtendedQueryStateMachine { } return self.avoidingStateMachineCoW { state -> Action in - state = .rowDescriptionReceived(queryContext, rowDescription.columns) + // In Postgres extended queries we receive the `rowDescription` before we send the + // `Bind` message. Well actually it's vice versa, but this is only true since we do + // pipelining during a query. + // + // In the actual protocol description we receive a rowDescription before the Bind + + // In Postgres extended queries we always request the response rows to be returned in + // `.binary` format. + let columns = rowDescription.columns.map { column -> PSQLBackendMessage.RowDescription.Column in + var column = column + column.format = .binary + return column + } + state = .rowDescriptionReceived(queryContext, columns) return .wait } } @@ -157,7 +170,7 @@ struct ExtendedQueryStateMachine { return self.avoidingStateMachineCoW { state -> Action in let row = dataRow.columns.enumerated().map { (index, buffer) in - PSQLData(bytes: buffer, dataType: columns[index].dataType) + PSQLData(bytes: buffer, dataType: columns[index].dataType, format: columns[index].format) } buffer.append(row) state = .bufferingRows(columns, buffer, readOnEmpty: readOnEmpty) @@ -174,7 +187,7 @@ struct ExtendedQueryStateMachine { return self.avoidingStateMachineCoW { state -> Action in precondition(buffer.isEmpty, "Expected the buffer to be empty") let row = dataRow.columns.enumerated().map { (index, buffer) in - PSQLData(bytes: buffer, dataType: columns[index].dataType) + PSQLData(bytes: buffer, dataType: columns[index].dataType, format: columns[index].format) } state = .bufferingRows(columns, buffer, readOnEmpty: false) diff --git a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift index 1a3e5cae..607a0d5b 100644 --- a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift @@ -72,6 +72,10 @@ extension Array: PSQLEncodable where Element: PSQLArrayElement { Element.psqlArrayType } + var psqlFormat: PSQLFormat { + .binary + } + func encode(into buffer: inout ByteBuffer, context: PSQLEncodingContext) throws { // 0 if empty, 1 if not buffer.writeInteger(self.isEmpty ? 0 : 1, as: UInt32.self) @@ -98,7 +102,12 @@ extension Array: PSQLEncodable where Element: PSQLArrayElement { extension Array: PSQLDecodable where Element: PSQLArrayElement { - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> Array { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Array { + guard case .binary = format else { + // currently we only support decoding arrays in binary format. + throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + } + guard let isNotEmpty = buffer.readInteger(as: Int32.self), (0...1).contains(isNotEmpty) else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } @@ -135,7 +144,7 @@ extension Array: PSQLDecodable where Element: PSQLArrayElement { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } - let element = try Element.decode(from: &elementBuffer, type: elementType, context: context) + let element = try Element.decode(from: &elementBuffer, type: elementType, format: format, context: context) result.append(element) } diff --git a/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift index 83d5ec0c..f67031c0 100644 --- a/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift @@ -3,18 +3,42 @@ extension Bool: PSQLCodable { .bool } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> Bool { - guard type == .bool, buffer.readableBytes == 1 else { + var psqlFormat: PSQLFormat { + .binary + } + + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Bool { + guard type == .bool else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } - switch buffer.readInteger(as: UInt8.self) { - case .some(0): - return false - case .some(1): - return true - default: - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + switch format { + case .binary: + guard buffer.readableBytes == 1 else { + throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + } + + switch buffer.readInteger(as: UInt8.self) { + case .some(0): + return false + case .some(1): + return true + default: + throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + } + case .text: + guard buffer.readableBytes == 1 else { + throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + } + + switch buffer.readInteger(as: UInt8.self) { + case .some(UInt8(ascii: "f")): + return false + case .some(UInt8(ascii: "t")): + return true + default: + throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + } } } diff --git a/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift index 34955cb3..3745e704 100644 --- a/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift @@ -6,6 +6,10 @@ extension PSQLEncodable where Self: Sequence, Self.Element == UInt8 { .bytea } + var psqlFormat: PSQLFormat { + .binary + } + func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { byteBuffer.writeBytes(self) } @@ -16,12 +20,16 @@ extension ByteBuffer: PSQLCodable { .bytea } + var psqlFormat: PSQLFormat { + .binary + } + func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { var copyOfSelf = self // dirty hack byteBuffer.writeBuffer(©OfSelf) } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> Self { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { return buffer } } @@ -30,12 +38,16 @@ extension Data: PSQLCodable { var psqlType: PSQLDataType { .bytea } - + + var psqlFormat: PSQLFormat { + .binary + } + func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { byteBuffer.writeBytes(self) } - - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> Self { + + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { return buffer.readData(length: buffer.readableBytes, byteTransferStrategy: .automatic)! } } diff --git a/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift index 9ac5bf70..a0e9efff 100644 --- a/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift @@ -5,7 +5,11 @@ extension Date: PSQLCodable { .timestamptz } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> Self { + var psqlFormat: PSQLFormat { + .binary + } + + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { switch type { case .timestamp, .timestamptz: guard buffer.readableBytes == 8, let microseconds = buffer.readInteger(as: Int64.self) else { diff --git a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift index 505ba1b0..be9bc045 100644 --- a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift @@ -3,18 +3,27 @@ extension Float: PSQLCodable { .float4 } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> Float { - switch type { - case .float4: + var psqlFormat: PSQLFormat { + .binary + } + + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Float { + switch (format, type) { + case (.binary, .float4): guard buffer.readableBytes == 4, let float = buffer.readFloat() else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } return float - case .float8: + case (.binary, .float8): guard buffer.readableBytes == 8, let double = buffer.readDouble() else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } return Float(double) + case (.text, .float4), (.text, .float8): + guard let string = buffer.readString(length: buffer.readableBytes), let value = Float(string) else { + throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + } + return value default: throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } @@ -30,18 +39,27 @@ extension Double: PSQLCodable { .float8 } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> Double { - switch type { - case .float4: + var psqlFormat: PSQLFormat { + .binary + } + + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Double { + switch (format, type) { + case (.binary, .float4): guard buffer.readableBytes == 4, let float = buffer.readFloat() else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } return Double(float) - case .float8: + case (.binary, .float8): guard buffer.readableBytes == 8, let double = buffer.readDouble() else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } return double + case (.text, .float4), (.text, .float8): + guard let string = buffer.readString(length: buffer.readableBytes), let value = Double(string) else { + throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + } + return value default: throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } diff --git a/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift index 3fd11733..11c2c46c 100644 --- a/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift @@ -3,8 +3,12 @@ extension UInt8: PSQLCodable { .char } + var psqlFormat: PSQLFormat { + .binary + } + // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> Self { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { switch type { case .bpchar, .char: guard buffer.readableBytes == 1, let value = buffer.readInteger(as: UInt8.self) else { @@ -29,14 +33,23 @@ extension Int16: PSQLCodable { .int2 } + var psqlFormat: PSQLFormat { + .binary + } + // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> Self { - switch type { - case .int2: + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + switch (format, type) { + case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } return value + case (.text, .int2): + guard let string = buffer.readString(length: buffer.readableBytes), let value = Int16(string) else { + throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + } + return value default: throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } @@ -53,21 +66,28 @@ extension Int32: PSQLCodable { .int4 } + var psqlFormat: PSQLFormat { + .binary + } + // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> Self { - switch type { - case .int2: + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + switch (format, type) { + case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } - return Int32(value) - case .int4: + case (.binary, .int4): guard buffer.readableBytes == 4, let value = buffer.readInteger(as: Int32.self) else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } - return Int32(value) + case (.text, .int2), (.text, .int4): + guard let string = buffer.readString(length: buffer.readableBytes), let value = Int32(string) else { + throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + } + return value default: throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } @@ -84,26 +104,32 @@ extension Int64: PSQLCodable { .int8 } + var psqlFormat: PSQLFormat { + .binary + } + // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> Self { - switch type { - case .int2: + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + switch (format, type) { + case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } - return Int64(value) - case .int4: + case (.binary, .int4): guard buffer.readableBytes == 4, let value = buffer.readInteger(as: Int32.self) else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } - return Int64(value) - case .int8: + case (.binary, .int8): guard buffer.readableBytes == 8, let value = buffer.readInteger(as: Int64.self) else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } - + return value + case (.text, .int2), (.text, .int4), (.text, .int8): + guard let string = buffer.readString(length: buffer.readableBytes), let value = Int64(string) else { + throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + } return value default: throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) @@ -128,26 +154,32 @@ extension Int: PSQLCodable { } } + var psqlFormat: PSQLFormat { + .binary + } + // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> Self { - switch type { - case .int2: + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + switch (format, type) { + case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } - return Int(value) - case .int4: + case (.binary, .int4): guard buffer.readableBytes == 4, let value = buffer.readInteger(as: Int32.self) else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } - return Int(value) - case .int8 where Int.bitWidth == 64: + case (.binary, .int8) where Int.bitWidth == 64: guard buffer.readableBytes == 8, let value = buffer.readInteger(as: Int.self) else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } - + return value + case (.text, .int2), (.text, .int4), (.text, .int8): + guard let string = buffer.readString(length: buffer.readableBytes), let value = Int(string) else { + throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + } return value default: throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) diff --git a/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift index 52bbed22..8ca5f08c 100644 --- a/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift @@ -9,15 +9,18 @@ extension PSQLCodable where Self: Codable { .jsonb } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> Self { - switch type { - case .jsonb: + var psqlFormat: PSQLFormat { + .binary + } + + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + switch (format, type) { + case (.binary, .jsonb): guard JSONBVersionByte == buffer.readInteger(as: UInt8.self) else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } - return try context.jsonDecoder.decode(Self.self, from: buffer) - case .json: + case (.binary, .json), (.text, .jsonb), (.text, .json): return try context.jsonDecoder.decode(Self.self, from: buffer) default: throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) diff --git a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift index 28f1d919..0005f7f8 100644 --- a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift @@ -1,5 +1,5 @@ extension Optional: PSQLDecodable where Wrapped: PSQLDecodable { - static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> Optional { + static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Optional { preconditionFailure("This code path should never be hit.") // The code path for decoding an optional should be: // -> PSQLData.decode(as: String?.self) @@ -18,6 +18,15 @@ extension Optional: PSQLEncodable where Wrapped: PSQLEncodable { } } + var psqlFormat: PSQLFormat { + switch self { + case .some(let value): + return value.psqlFormat + case .none: + return .binary + } + } + func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws { preconditionFailure("Should never be hit, since `encodeRaw` is implemented.") } diff --git a/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift index 1d833ccf..f2096e77 100644 --- a/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift @@ -3,8 +3,12 @@ extension PSQLCodable where Self: RawRepresentable, RawValue: PSQLCodable { self.rawValue.psqlType } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> Self { - guard let rawValue = try? RawValue.decode(from: &buffer, type: type, context: context), + var psqlFormat: PSQLFormat { + self.rawValue.psqlFormat + } + + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + guard let rawValue = try? RawValue.decode(from: &buffer, type: type, format: format, context: context), let selfValue = Self.init(rawValue: rawValue) else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } diff --git a/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift index 073b8502..9e325435 100644 --- a/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift @@ -5,18 +5,24 @@ extension String: PSQLCodable { .text } + var psqlFormat: PSQLFormat { + .binary + } + func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { byteBuffer.writeString(self) } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> String { - switch type { - case .varchar, .text, .name: + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> String { + switch (format, type) { + case (_, .varchar), + (_, .text), + (_, .name): // we can force unwrap here, since this method only fails if there are not enough // bytes available. return buffer.readString(length: buffer.readableBytes)! - case .uuid: - guard let uuid = try? UUID.decode(from: &buffer, type: .uuid, context: context) else { + case (_, .uuid): + guard let uuid = try? UUID.decode(from: &buffer, type: .uuid, format: format, context: context) else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } return uuid.uuidString diff --git a/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift index 7cb66441..fcabd094 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift @@ -7,6 +7,10 @@ extension UUID: PSQLCodable { .uuid } + var psqlFormat: PSQLFormat { + .binary + } + func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { let uuid = self.uuid byteBuffer.writeBytes([ @@ -17,15 +21,23 @@ extension UUID: PSQLCodable { ]) } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> UUID { - switch type { - case .uuid: + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> UUID { + switch (format, type) { + case (.binary, .uuid): guard let uuid = buffer.readUUID() else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } return uuid - case .varchar, .text: - guard let uuid = buffer.readString(length: buffer.readableBytes).flatMap({ UUID(uuidString: $0) }) else { + case (.binary, .varchar), + (.binary, .text), + (.text, .uuid), + (.text, .text), + (.text, .varchar): + guard buffer.readableBytes == 36 else { + throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + } + + guard let uuid = buffer.readString(length: 36).flatMap({ UUID(uuidString: $0) }) else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } return uuid diff --git a/Sources/PostgresNIO/New/Messages/Bind.swift b/Sources/PostgresNIO/New/Messages/Bind.swift index 8a77bda4..f69b8d4a 100644 --- a/Sources/PostgresNIO/New/Messages/Bind.swift +++ b/Sources/PostgresNIO/New/Messages/Bind.swift @@ -18,10 +18,12 @@ extension PSQLFrontendMessage { // zero to indicate that there are no parameters or that the parameters all use the // default format (text); or one, in which case the specified format code is applied // to all parameters; or it can equal the actual number of parameters. - buffer.writeInteger(1, as: Int16.self) + buffer.writeInteger(Int16(self.parameters.count)) // The parameter format codes. Each must presently be zero (text) or one (binary). - buffer.writeInteger(1, as: Int16.self) + self.parameters.forEach { + buffer.writeInteger($0.psqlFormat.rawValue) + } buffer.writeInteger(Int16(self.parameters.count)) @@ -38,7 +40,7 @@ extension PSQLFrontendMessage { // result columns of the query. buffer.writeInteger(1, as: Int16.self) // The result-column format codes. Each must presently be zero (text) or one (binary). - buffer.writeInteger(1, as: Int16.self) + buffer.writeInteger(PSQLFormat.binary.rawValue, as: Int16.self) } } } diff --git a/Sources/PostgresNIO/New/Messages/RowDescription.swift b/Sources/PostgresNIO/New/Messages/RowDescription.swift index c467db7b..fdb495a5 100644 --- a/Sources/PostgresNIO/New/Messages/RowDescription.swift +++ b/Sources/PostgresNIO/New/Messages/RowDescription.swift @@ -23,9 +23,9 @@ extension PSQLBackendMessage { /// The type modifier (see pg_attribute.atttypmod). The meaning of the modifier is type-specific. var dataTypeModifier: Int32 - /// The format code being used for the field. Currently will be zero (text) or one (binary). In a RowDescription returned - /// from the statement variant of Describe, the format code is not yet known and will always be zero. - var formatCode: PSQLFormatCode + /// The format being used for the field. Currently will be text or binary. In a RowDescription returned + /// from the statement variant of Describe, the format code is not yet known and will always be text. + var format: PSQLFormat } static func decode(from buffer: inout ByteBuffer) throws -> Self { @@ -53,8 +53,8 @@ extension PSQLBackendMessage { let dataTypeModifier = buffer.readInteger(as: Int32.self)! let formatCodeInt16 = buffer.readInteger(as: Int16.self)! - guard let formatCode = PSQLFormatCode(rawValue: formatCodeInt16) else { - throw PartialDecodingError.valueNotRawRepresentable(value: formatCodeInt16, asType: PSQLFormatCode.self) + guard let format = PSQLFormat(rawValue: formatCodeInt16) else { + throw PartialDecodingError.valueNotRawRepresentable(value: formatCodeInt16, asType: PSQLFormat.self) } let field = Column( @@ -64,7 +64,7 @@ extension PSQLBackendMessage { dataType: dataType, dataTypeSize: dataTypeSize, dataTypeModifier: dataTypeModifier, - formatCode: formatCode) + format: format) result.append(field) } diff --git a/Sources/PostgresNIO/New/PSQLCodable.swift b/Sources/PostgresNIO/New/PSQLCodable.swift index dd15511d..32e30e49 100644 --- a/Sources/PostgresNIO/New/PSQLCodable.swift +++ b/Sources/PostgresNIO/New/PSQLCodable.swift @@ -3,6 +3,9 @@ protocol PSQLEncodable { /// identifies the data type that we will encode into `byteBuffer` in `encode` var psqlType: PSQLDataType { get } + /// identifies the postgres format that is used to encode the value into `byteBuffer` in `encode` + var psqlFormat: PSQLFormat { get } + /// Encode the entity into the `byteBuffer` in Postgres binary format, without setting /// the byte count. This method is called from the default `encodeRaw` implementation. func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws @@ -16,8 +19,18 @@ protocol PSQLEncodable { /// A type that can decode itself from a postgres wire binary representation. protocol PSQLDecodable { - /// decode an entity from the `byteBuffer` in postgres binary format - static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> Self + /// Decode an entity from the `byteBuffer` in postgres wire format + /// + /// - Parameters: + /// - byteBuffer: A `ByteBuffer` to decode. The byteBuffer is sliced in such a way that it is expected + /// that the complete buffer is consumed for decoding + /// - type: The postgres data type. Depending on this type the `byteBuffer`'s bytes need to be interpreted + /// in different ways. + /// - format: The postgres wire format. Can be `.text` or `.binary` + /// - context: A `PSQLDecodingContext` providing context for decoding. This includes a `JSONDecoder` + /// to use when decoding json and metadata to create better errors. + /// - Returns: A decoded object + static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self } /// A type that can be encoded into and decoded from a postgres binary format diff --git a/Sources/PostgresNIO/New/PSQLData.swift b/Sources/PostgresNIO/New/PSQLData.swift index b31f4faf..522ae585 100644 --- a/Sources/PostgresNIO/New/PSQLData.swift +++ b/Sources/PostgresNIO/New/PSQLData.swift @@ -1,8 +1,9 @@ -/// The format code being used for the field. -/// Currently will be zero (text) or one (binary). -/// In a RowDescription returned from the statement variant of Describe, -/// the format code is not yet known and will always be zero. -enum PSQLFormatCode: Int16 { +/// The format the postgres types are encoded in on the wire. +/// +/// Currently there a two wire formats supported: +/// - text +/// - binary +enum PSQLFormat: Int16 { case text = 0 case binary = 1 } @@ -11,11 +12,13 @@ struct PSQLData: Equatable { @usableFromInline var bytes: ByteBuffer? @usableFromInline var dataType: PSQLDataType + @usableFromInline var format: PSQLFormat /// use this only for testing - init(bytes: ByteBuffer?, dataType: PSQLDataType) { + init(bytes: ByteBuffer?, dataType: PSQLDataType, format: PSQLFormat) { self.bytes = bytes self.dataType = dataType + self.format = format } @inlinable @@ -29,7 +32,7 @@ struct PSQLData: Equatable { case .none: throw PSQLCastingError.missingData(targetType: type, type: self.dataType, context: context) case .some(var buffer): - return try T.decode(from: &buffer, type: self.dataType, context: context) + return try T.decode(from: &buffer, type: self.dataType, format: self.format, context: context) } } @@ -39,7 +42,7 @@ struct PSQLData: Equatable { case .none: return nil case .some(var buffer): - return try T.decode(from: &buffer, type: self.dataType, context: context) + return try T.decode(from: &buffer, type: self.dataType, format: self.format, context: context) } } } diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index 23ddb4db..db854864 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -30,6 +30,10 @@ extension PostgresData: PSQLEncodable { PSQLDataType(Int32(self.type.rawValue)) } + var psqlFormat: PSQLFormat { + .binary + } + func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws { preconditionFailure("Should never be hit, since `encodeRaw` is implemented.") } @@ -47,7 +51,7 @@ extension PostgresData: PSQLEncodable { } extension PostgresData: PSQLDecodable { - static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, context: PSQLDecodingContext) throws -> PostgresData { + static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> PostgresData { let myBuffer = byteBuffer.readSlice(length: byteBuffer.readableBytes)! return PostgresData(type: PostgresDataType(UInt32(type.rawValue)), typeModifier: nil, formatCode: .binary, value: myBuffer) @@ -97,7 +101,7 @@ extension PSQLError { } extension PostgresFormatCode { - init(psqlFormatCode: PSQLFormatCode) { + init(psqlFormatCode: PSQLFormat) { switch psqlFormatCode { case .binary: self = .binary diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index ed421f58..2d9767c9 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -262,6 +262,28 @@ final class IntegrationTests: XCTestCase { XCTAssertNil(try rows?.next().wait()) } + func testDecodeUUID() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PSQLConnection? + XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + var rows: PSQLRows? + XCTAssertNoThrow(rows = try conn?.query(""" + SELECT '2c68f645-9ca6-468b-b193-ee97f241c2f8'::UUID as uuid + """, logger: .psqlTest).wait()) + + var row: PSQLRows.Row? + XCTAssertNoThrow(row = try rows?.next().wait()) + + XCTAssertEqual(try row?.decode(column: "uuid", as: UUID.self), UUID(uuidString: "2c68f645-9ca6-468b-b193-ee97f241c2f8")) + + XCTAssertNil(try rows?.next().wait()) + } + func testRoundTripJSONB() { struct Object: Codable, PSQLCodable { let foo: Int diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index 08005156..32766aab 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -34,19 +34,26 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) - let columns: [PSQLBackendMessage.RowDescription.Column] = [ - .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, formatCode: .text) + // We need to ensure that even though the row description from the wire says that we + // will receive data in `.text` format, we will actually receive it in binary format, + // since we requested it in binary with our bind message. + let input: [PSQLBackendMessage.RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) ] + let expected: [PSQLBackendMessage.RowDescription.Column] = input.map { + .init(name: $0.name, tableOID: $0.tableOID, columnAttributeNumber: $0.columnAttributeNumber, dataType: $0.dataType, + dataTypeSize: $0.dataTypeSize, dataTypeModifier: $0.dataTypeModifier, format: .binary) + } - XCTAssertEqual(state.rowDescriptionReceived(.init(columns: columns)), .wait) - XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(queryContext, columns: columns)) + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(queryContext, columns: expected)) let rowContent = ByteBuffer(string: "test") XCTAssertEqual(state.dataRowReceived(.init(columns: [rowContent])), .wait) XCTAssertEqual(state.readEventCaught(), .wait) let rowPromise = EmbeddedEventLoop().makePromise(of: StateMachineStreamNextResult.self) rowPromise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. - XCTAssertEqual(state.consumeNextQueryRow(promise: rowPromise), .forwardRow([.init(bytes: rowContent, dataType: .text)], to: rowPromise)) + XCTAssertEqual(state.consumeNextQueryRow(promise: rowPromise), .forwardRow([.init(bytes: rowContent, dataType: .text, format: .binary)], to: rowPromise)) XCTAssertEqual(state.commandCompletedReceived("SELECT 1"), .forwardStreamCompletedToCurrentQuery(CircularBuffer(), commandTag: "SELECT 1", read: true)) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift index 50870b15..d822e687 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift @@ -20,7 +20,7 @@ class PrepareStatementStateMachineTests: XCTestCase { XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) let columns: [PSQLBackendMessage.RowDescription.Column] = [ - .init(name: "id", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: -1, formatCode: .binary) + .init(name: "id", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: -1, format: .binary) ] XCTAssertEqual(state.rowDescriptionReceived(.init(columns: columns)), diff --git a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift index 243102e9..e771b080 100644 --- a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift @@ -61,7 +61,7 @@ class Array_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() XCTAssertNoThrow(try values.encode(into: &buffer, context: .forTests())) - let data = PSQLData(bytes: buffer, dataType: .textArray) + let data = PSQLData(bytes: buffer, dataType: .textArray, format: .binary) var result: [String]? XCTAssertNoThrow(result = try data.decode(as: [String].self, context: .forTests())) @@ -73,7 +73,7 @@ class Array_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() XCTAssertNoThrow(try values.encode(into: &buffer, context: .forTests())) - let data = PSQLData(bytes: buffer, dataType: .textArray) + let data = PSQLData(bytes: buffer, dataType: .textArray, format: .binary) var result: [String]? XCTAssertNoThrow(result = try data.decode(as: [String].self, context: .forTests())) @@ -85,7 +85,7 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int32(2)) // invalid value buffer.writeInteger(Int32(0)) buffer.writeInteger(String.psqlArrayElementType.rawValue) - let data = PSQLData(bytes: buffer, dataType: .textArray) + let data = PSQLData(bytes: buffer, dataType: .textArray, format: .binary) XCTAssertThrowsError(try data.decode(as: [String].self, context: .forTests())) { error in XCTAssert(error is PSQLCastingError) @@ -97,7 +97,7 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int32(0)) // is empty buffer.writeInteger(Int32(1)) // invalid value, must always be 0 buffer.writeInteger(String.psqlArrayElementType.rawValue) - let data = PSQLData(bytes: buffer, dataType: .textArray) + let data = PSQLData(bytes: buffer, dataType: .textArray, format: .binary) XCTAssertThrowsError(try data.decode(as: [String].self, context: .forTests())) { error in XCTAssert(error is PSQLCastingError) @@ -108,7 +108,7 @@ class Array_PSQLCodableTests: XCTestCase { let value: Int64 = 1 << 32 var buffer = ByteBuffer() value.encode(into: &buffer, context: .forTests()) - let data = PSQLData(bytes: buffer, dataType: .textArray) + let data = PSQLData(bytes: buffer, dataType: .textArray, format: .binary) XCTAssertThrowsError(try data.decode(as: [String].self, context: .forTests())) { error in XCTAssert(error is PSQLCastingError) @@ -122,7 +122,7 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(String.psqlArrayElementType.rawValue) buffer.writeInteger(Int32(-123)) // expected element count buffer.writeInteger(Int32(1)) // dimensions... must be one - let data = PSQLData(bytes: buffer, dataType: .textArray) + let data = PSQLData(bytes: buffer, dataType: .textArray, format: .binary) XCTAssertThrowsError(try data.decode(as: [String].self, context: .forTests())) { error in XCTAssert(error is PSQLCastingError) @@ -136,7 +136,7 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(String.psqlArrayElementType.rawValue) buffer.writeInteger(Int32(1)) // expected element count buffer.writeInteger(Int32(2)) // dimensions... must be one - let data = PSQLData(bytes: buffer, dataType: .textArray) + let data = PSQLData(bytes: buffer, dataType: .textArray, format: .binary) XCTAssertThrowsError(try data.decode(as: [String].self, context: .forTests())) { error in XCTAssert(error is PSQLCastingError) @@ -151,7 +151,7 @@ class Array_PSQLCodableTests: XCTestCase { unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // expected element count unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // dimensions unexpectedEndInElementLengthBuffer.writeInteger(Int16(1)) // length of element, must be Int32 - let data = PSQLData(bytes: unexpectedEndInElementLengthBuffer, dataType: .textArray) + let data = PSQLData(bytes: unexpectedEndInElementLengthBuffer, dataType: .textArray, format: .binary) XCTAssertThrowsError(try data.decode(as: [String].self, context: .forTests())) { error in XCTAssert(error is PSQLCastingError) @@ -165,7 +165,7 @@ class Array_PSQLCodableTests: XCTestCase { unexpectedEndInElementBuffer.writeInteger(Int32(1)) // dimensions unexpectedEndInElementBuffer.writeInteger(Int32(12)) // length of element, must be Int32 unexpectedEndInElementBuffer.writeString("Hello World") // only 11 bytes, 12 needed! - let unexpectedEndInElementData = PSQLData(bytes: unexpectedEndInElementBuffer, dataType: .textArray) + let unexpectedEndInElementData = PSQLData(bytes: unexpectedEndInElementBuffer, dataType: .textArray, format: .binary) XCTAssertThrowsError(try unexpectedEndInElementData.decode(as: [String].self, context: .forTests())) { error in XCTAssert(error is PSQLCastingError) diff --git a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift index 8e2b0e54..00ec1bf5 100644 --- a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift @@ -3,50 +3,90 @@ import XCTest class Bool_PSQLCodableTests: XCTestCase { - func testTrueRoundTrip() { + // MARK: - Binary + + func testBinaryTrueRoundTrip() { let value = true var buffer = ByteBuffer() value.encode(into: &buffer, context: .forTests()) XCTAssertEqual(value.psqlType, .bool) + XCTAssertEqual(value.psqlFormat, .binary) XCTAssertEqual(buffer.readableBytes, 1) XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 1) - let data = PSQLData(bytes: buffer, dataType: .bool) + let data = PSQLData(bytes: buffer, dataType: .bool, format: .binary) var result: Bool? XCTAssertNoThrow(result = try data.decode(as: Bool.self, context: .forTests())) XCTAssertEqual(value, result) } - func testFalseRoundTrip() { + func testBinaryFalseRoundTrip() { let value = false var buffer = ByteBuffer() value.encode(into: &buffer, context: .forTests()) XCTAssertEqual(value.psqlType, .bool) + XCTAssertEqual(value.psqlFormat, .binary) XCTAssertEqual(buffer.readableBytes, 1) XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 0) - let data = PSQLData(bytes: buffer, dataType: .bool) + let data = PSQLData(bytes: buffer, dataType: .bool, format: .binary) var result: Bool? XCTAssertNoThrow(result = try data.decode(as: Bool.self, context: .forTests())) XCTAssertEqual(value, result) } - - func testDecodeBoolInvalidLength() { + + func testBinaryDecodeBoolInvalidLength() { var buffer = ByteBuffer() buffer.writeInteger(Int64(1)) - let data = PSQLData(bytes: buffer, dataType: .bool) + let data = PSQLData(bytes: buffer, dataType: .bool, format: .binary) XCTAssertThrowsError(try data.decode(as: Bool.self, context: .forTests())) { error in XCTAssert(error is PSQLCastingError) } } - func testDecodeBoolInvalidValue() { + func testBinaryDecodeBoolInvalidValue() { + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(13)) + let data = PSQLData(bytes: buffer, dataType: .bool, format: .binary) + + XCTAssertThrowsError(try data.decode(as: Bool.self, context: .forTests())) { error in + XCTAssert(error is PSQLCastingError) + } + } + + // MARK: - Text + + func testTextTrueDecode() { + let value = true + + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(ascii: "t")) + let data = PSQLData(bytes: buffer, dataType: .bool, format: .text) + + var result: Bool? + XCTAssertNoThrow(result = try data.decode(as: Bool.self, context: .forTests())) + XCTAssertEqual(value, result) + } + + func testTextFalseDecode() { + let value = false + + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(ascii: "f")) + let data = PSQLData(bytes: buffer, dataType: .bool, format: .text) + + var result: Bool? + XCTAssertNoThrow(result = try data.decode(as: Bool.self, context: .forTests())) + XCTAssertEqual(value, result) + } + + func testTextDecodeBoolInvalidValue() { var buffer = ByteBuffer() buffer.writeInteger(UInt8(13)) - let data = PSQLData(bytes: buffer, dataType: .bool) + let data = PSQLData(bytes: buffer, dataType: .bool, format: .text) XCTAssertThrowsError(try data.decode(as: Bool.self, context: .forTests())) { error in XCTAssert(error is PSQLCastingError) diff --git a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift index a57676a4..914efbff 100644 --- a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift @@ -9,7 +9,7 @@ class Bytes_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() data.encode(into: &buffer, context: .forTests()) XCTAssertEqual(data.psqlType, .bytea) - let psqlData = PSQLData(bytes: buffer, dataType: .bytea) + let psqlData = PSQLData(bytes: buffer, dataType: .bytea, format: .binary) var result: Data? XCTAssertNoThrow(result = try psqlData.decode(as: Data.self, context: .forTests())) @@ -22,7 +22,7 @@ class Bytes_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() bytes.encode(into: &buffer, context: .forTests()) XCTAssertEqual(bytes.psqlType, .bytea) - let psqlData = PSQLData(bytes: buffer, dataType: .bytea) + let psqlData = PSQLData(bytes: buffer, dataType: .bytea, format: .binary) var result: ByteBuffer? XCTAssertNoThrow(result = try psqlData.decode(as: ByteBuffer.self, context: .forTests())) diff --git a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift index 04f07d60..38f6dd87 100644 --- a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift @@ -10,7 +10,7 @@ class Date_PSQLCodableTests: XCTestCase { value.encode(into: &buffer, context: .forTests()) XCTAssertEqual(value.psqlType, .timestamptz) XCTAssertEqual(buffer.readableBytes, 8) - let data = PSQLData(bytes: buffer, dataType: .timestamptz) + let data = PSQLData(bytes: buffer, dataType: .timestamptz, format: .binary) var result: Date? XCTAssertNoThrow(result = try data.decode(as: Date.self, context: .forTests())) @@ -20,7 +20,7 @@ class Date_PSQLCodableTests: XCTestCase { func testDecodeRandomDate() { var buffer = ByteBuffer() buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) - let data = PSQLData(bytes: buffer, dataType: .timestamptz) + let data = PSQLData(bytes: buffer, dataType: .timestamptz, format: .binary) var result: Date? XCTAssertNoThrow(result = try data.decode(as: Date.self, context: .forTests())) @@ -31,7 +31,7 @@ class Date_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) - let data = PSQLData(bytes: buffer, dataType: .timestamptz) + let data = PSQLData(bytes: buffer, dataType: .timestamptz, format: .binary) XCTAssertThrowsError(try data.decode(as: Date.self, context: .forTests())) { error in XCTAssert(error is PSQLCastingError) @@ -41,7 +41,7 @@ class Date_PSQLCodableTests: XCTestCase { func testDecodeDate() { var firstDateBuffer = ByteBuffer() firstDateBuffer.writeInteger(Int32.min) - let firstDateData = PSQLData(bytes: firstDateBuffer, dataType: .date) + let firstDateData = PSQLData(bytes: firstDateBuffer, dataType: .date, format: .binary) var firstDate: Date? XCTAssertNoThrow(firstDate = try firstDateData.decode(as: Date.self, context: .forTests())) @@ -49,7 +49,7 @@ class Date_PSQLCodableTests: XCTestCase { var lastDateBuffer = ByteBuffer() lastDateBuffer.writeInteger(Int32.max) - let lastDateData = PSQLData(bytes: lastDateBuffer, dataType: .date) + let lastDateData = PSQLData(bytes: lastDateBuffer, dataType: .date, format: .binary) var lastDate: Date? XCTAssertNoThrow(lastDate = try lastDateData.decode(as: Date.self, context: .forTests())) @@ -59,7 +59,7 @@ class Date_PSQLCodableTests: XCTestCase { func testDecodeDateFromTimestamp() { var firstDateBuffer = ByteBuffer() firstDateBuffer.writeInteger(Int32.min) - let firstDateData = PSQLData(bytes: firstDateBuffer, dataType: .date) + let firstDateData = PSQLData(bytes: firstDateBuffer, dataType: .date, format: .binary) var firstDate: Date? XCTAssertNoThrow(firstDate = try firstDateData.decode(as: Date.self, context: .forTests())) @@ -67,7 +67,7 @@ class Date_PSQLCodableTests: XCTestCase { var lastDateBuffer = ByteBuffer() lastDateBuffer.writeInteger(Int32.max) - let lastDateData = PSQLData(bytes: lastDateBuffer, dataType: .date) + let lastDateData = PSQLData(bytes: lastDateBuffer, dataType: .date, format: .binary) var lastDate: Date? XCTAssertNoThrow(lastDate = try lastDateData.decode(as: Date.self, context: .forTests())) @@ -77,7 +77,7 @@ class Date_PSQLCodableTests: XCTestCase { func testDecodeDateFailsWithToMuchData() { var buffer = ByteBuffer() buffer.writeInteger(Int64(0)) - let data = PSQLData(bytes: buffer, dataType: .date) + let data = PSQLData(bytes: buffer, dataType: .date, format: .binary) XCTAssertThrowsError(try data.decode(as: Date.self, context: .forTests())) { error in XCTAssert(error is PSQLCastingError) @@ -87,7 +87,7 @@ class Date_PSQLCodableTests: XCTestCase { func testDecodeDateFailsWithWrongDataType() { var buffer = ByteBuffer() buffer.writeInteger(Int64(0)) - let data = PSQLData(bytes: buffer, dataType: .int8) + let data = PSQLData(bytes: buffer, dataType: .int8, format: .binary) XCTAssertThrowsError(try data.decode(as: Date.self, context: .forTests())) { error in XCTAssert(error is PSQLCastingError) diff --git a/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift index 19fb3a84..143f907c 100644 --- a/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift @@ -11,7 +11,7 @@ class Float_PSQLCodableTests: XCTestCase { value.encode(into: &buffer, context: .forTests()) XCTAssertEqual(value.psqlType, .float8) XCTAssertEqual(buffer.readableBytes, 8) - let data = PSQLData(bytes: buffer, dataType: .float8) + let data = PSQLData(bytes: buffer, dataType: .float8, format: .binary) var result: Double? XCTAssertNoThrow(result = try data.decode(as: Double.self, context: .forTests())) @@ -27,7 +27,7 @@ class Float_PSQLCodableTests: XCTestCase { value.encode(into: &buffer, context: .forTests()) XCTAssertEqual(value.psqlType, .float4) XCTAssertEqual(buffer.readableBytes, 4) - let data = PSQLData(bytes: buffer, dataType: .float4) + let data = PSQLData(bytes: buffer, dataType: .float4, format: .binary) var result: Float? XCTAssertNoThrow(result = try data.decode(as: Float.self, context: .forTests())) @@ -42,7 +42,7 @@ class Float_PSQLCodableTests: XCTestCase { value.encode(into: &buffer, context: .forTests()) XCTAssertEqual(value.psqlType, .float8) XCTAssertEqual(buffer.readableBytes, 8) - let data = PSQLData(bytes: buffer, dataType: .float8) + let data = PSQLData(bytes: buffer, dataType: .float8, format: .binary) var result: Double? XCTAssertNoThrow(result = try data.decode(as: Double.self, context: .forTests())) @@ -56,7 +56,7 @@ class Float_PSQLCodableTests: XCTestCase { value.encode(into: &buffer, context: .forTests()) XCTAssertEqual(value.psqlType, .float8) XCTAssertEqual(buffer.readableBytes, 8) - let data = PSQLData(bytes: buffer, dataType: .float8) + let data = PSQLData(bytes: buffer, dataType: .float8, format: .binary) var result: Double? XCTAssertNoThrow(result = try data.decode(as: Double.self, context: .forTests())) @@ -71,7 +71,7 @@ class Float_PSQLCodableTests: XCTestCase { value.encode(into: &buffer, context: .forTests()) XCTAssertEqual(value.psqlType, .float4) XCTAssertEqual(buffer.readableBytes, 4) - let data = PSQLData(bytes: buffer, dataType: .float4) + let data = PSQLData(bytes: buffer, dataType: .float4, format: .binary) var result: Double? XCTAssertNoThrow(result = try data.decode(as: Double.self, context: .forTests())) @@ -87,7 +87,7 @@ class Float_PSQLCodableTests: XCTestCase { value.encode(into: &buffer, context: .forTests()) XCTAssertEqual(value.psqlType, .float8) XCTAssertEqual(buffer.readableBytes, 8) - let data = PSQLData(bytes: buffer, dataType: .float8) + let data = PSQLData(bytes: buffer, dataType: .float8, format: .binary) var result: Float? XCTAssertNoThrow(result = try data.decode(as: Float.self, context: .forTests())) @@ -100,8 +100,8 @@ class Float_PSQLCodableTests: XCTestCase { eightByteBuffer.writeInteger(Int64(0)) var fourByteBuffer = ByteBuffer() fourByteBuffer.writeInteger(Int32(0)) - let toLongData = PSQLData(bytes: eightByteBuffer, dataType: .float4) - let toShortData = PSQLData(bytes: fourByteBuffer, dataType: .float8) + let toLongData = PSQLData(bytes: eightByteBuffer, dataType: .float4, format: .binary) + let toShortData = PSQLData(bytes: fourByteBuffer, dataType: .float8, format: .binary) XCTAssertThrowsError(try toLongData.decode(as: Double.self, context: .forTests())) { error in XCTAssert(error is PSQLCastingError) @@ -123,7 +123,7 @@ class Float_PSQLCodableTests: XCTestCase { func testDecodeFailureInvalidType() { var buffer = ByteBuffer() buffer.writeInteger(Int64(0)) - let data = PSQLData(bytes: buffer, dataType: .int8) + let data = PSQLData(bytes: buffer, dataType: .int8, format: .binary) XCTAssertThrowsError(try data.decode(as: Double.self, context: .forTests())) { error in XCTAssert(error is PSQLCastingError) diff --git a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift index 12219226..883a3dfc 100644 --- a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift @@ -20,7 +20,7 @@ class JSON_PSQLCodableTests: XCTestCase { // verify jsonb prefix byte XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 1) - let data = PSQLData(bytes: buffer, dataType: .jsonb) + let data = PSQLData(bytes: buffer, dataType: .jsonb, format: .binary) var result: Hello? XCTAssertNoThrow(result = try data.decode(as: Hello.self, context: .forTests())) XCTAssertEqual(result, hello) @@ -30,17 +30,32 @@ class JSON_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeString(#"{"hello":"world"}"#) - let data = PSQLData(bytes: buffer, dataType: .json) + let data = PSQLData(bytes: buffer, dataType: .json, format: .binary) var result: Hello? XCTAssertNoThrow(result = try data.decode(as: Hello.self, context: .forTests())) XCTAssertEqual(result, Hello(name: "world")) } + func testDecodeFromJSONAsText() { + let combinations : [(PSQLFormat, PSQLDataType)] = [ + (.text, .json), (.text, .jsonb), + ] + var buffer = ByteBuffer() + buffer.writeString(#"{"hello":"world"}"#) + + for (format, dataType) in combinations { + let data = PSQLData(bytes: buffer, dataType: dataType, format: format) + var result: Hello? + XCTAssertNoThrow(result = try data.decode(as: Hello.self, context: .forTests())) + XCTAssertEqual(result, Hello(name: "world")) + } + } + func testDecodeFromJSONBWithoutVersionPrefixByte() { var buffer = ByteBuffer() buffer.writeString(#"{"hello":"world"}"#) - let data = PSQLData(bytes: buffer, dataType: .jsonb) + let data = PSQLData(bytes: buffer, dataType: .jsonb, format: .binary) XCTAssertThrowsError(try data.decode(as: Hello.self, context: .forTests())) { error in XCTAssert(error is PSQLCastingError) } @@ -50,7 +65,7 @@ class JSON_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeString(#"{"hello":"world"}"#) - let data = PSQLData(bytes: buffer, dataType: .text) + let data = PSQLData(bytes: buffer, dataType: .text, format: .binary) XCTAssertThrowsError(try data.decode(as: Hello.self, context: .forTests())) { error in XCTAssert(error is PSQLCastingError) } diff --git a/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift index 14fbbae8..3900505b 100644 --- a/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift @@ -9,7 +9,7 @@ class Optional_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() value?.encode(into: &buffer, context: .forTests()) XCTAssertEqual(value.psqlType, .text) - let data = PSQLData(bytes: buffer, dataType: .text) + let data = PSQLData(bytes: buffer, dataType: .text, format: .binary) var result: String? XCTAssertNoThrow(result = try data.decode(as: String?.self, context: .forTests())) @@ -24,7 +24,7 @@ class Optional_PSQLCodableTests: XCTestCase { XCTAssertEqual(buffer.readableBytes, 0) XCTAssertEqual(value.psqlType, .null) - let data = PSQLData(bytes: nil, dataType: .text) + let data = PSQLData(bytes: nil, dataType: .text, format: .binary) var result: String? XCTAssertNoThrow(result = try data.decode(as: String?.self, context: .forTests())) @@ -40,7 +40,7 @@ class Optional_PSQLCodableTests: XCTestCase { XCTAssertNoThrow(try encodable.encodeRaw(into: &buffer, context: .forTests())) XCTAssertEqual(buffer.readableBytes, 20) XCTAssertEqual(buffer.readInteger(as: Int32.self), 16) - let data = PSQLData(bytes: buffer, dataType: .uuid) + let data = PSQLData(bytes: buffer, dataType: .uuid, format: .binary) var result: UUID? XCTAssertNoThrow(result = try data.decode(as: UUID?.self, context: .forTests())) @@ -57,7 +57,7 @@ class Optional_PSQLCodableTests: XCTestCase { XCTAssertEqual(buffer.readableBytes, 4) XCTAssertEqual(buffer.readInteger(as: Int32.self), -1) - let data = PSQLData(bytes: nil, dataType: .uuid) + let data = PSQLData(bytes: nil, dataType: .uuid, format: .binary) var result: UUID? XCTAssertNoThrow(result = try data.decode(as: UUID?.self, context: .forTests())) diff --git a/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift index ba28220e..1e472366 100644 --- a/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift @@ -17,7 +17,7 @@ class RawRepresentable_PSQLCodableTests: XCTestCase { XCTAssertNoThrow(try value.encode(into: &buffer, context: .forTests())) XCTAssertEqual(value.psqlType, Int16.psqlArrayElementType) XCTAssertEqual(buffer.readableBytes, 2) - let data = PSQLData(bytes: buffer, dataType: Int16.psqlArrayElementType) + let data = PSQLData(bytes: buffer, dataType: Int16.psqlArrayElementType, format: .binary) var result: MyRawRepresentable? XCTAssertNoThrow(result = try data.decode(as: MyRawRepresentable.self, context: .forTests())) @@ -28,7 +28,7 @@ class RawRepresentable_PSQLCodableTests: XCTestCase { func testDecodeInvalidRawTypeValue() { var buffer = ByteBuffer() buffer.writeInteger(Int16(4)) // out of bounds - let data = PSQLData(bytes: buffer, dataType: Int16.psqlArrayElementType) + let data = PSQLData(bytes: buffer, dataType: Int16.psqlArrayElementType, format: .binary) XCTAssertThrowsError(try data.decode(as: MyRawRepresentable.self, context: .forTests())) { error in XCTAssertEqual((error as? PSQLCastingError)?.line, #line - 1) @@ -40,7 +40,7 @@ class RawRepresentable_PSQLCodableTests: XCTestCase { func testDecodeInvalidUnderlyingTypeValue() { var buffer = ByteBuffer() buffer.writeInteger(Int32(1)) // out of bounds - let data = PSQLData(bytes: buffer, dataType: Int32.psqlArrayElementType) + let data = PSQLData(bytes: buffer, dataType: Int32.psqlArrayElementType, format: .binary) XCTAssertThrowsError(try data.decode(as: MyRawRepresentable.self, context: .forTests())) { error in XCTAssertEqual((error as? PSQLCastingError)?.line, #line - 1) diff --git a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift index faa00555..c8cce8f1 100644 --- a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift @@ -25,7 +25,7 @@ class String_PSQLCodableTests: XCTestCase { for dataType in dataTypes { var loopBuffer = buffer var result: String? - XCTAssertNoThrow(result = try String.decode(from: &loopBuffer, type: dataType, context: .forTests())) + XCTAssertNoThrow(result = try String.decode(from: &loopBuffer, type: dataType, format: .binary, context: .forTests())) XCTAssertEqual(result, expected) } } @@ -36,7 +36,7 @@ class String_PSQLCodableTests: XCTestCase { for dataType in dataTypes { var loopBuffer = buffer - XCTAssertThrowsError(try String.decode(from: &loopBuffer, type: dataType, context: .forTests())) { error in + XCTAssertThrowsError(try String.decode(from: &loopBuffer, type: dataType, format: .binary, context: .forTests())) { error in XCTAssertEqual((error as? PSQLCastingError)?.line, #line - 1) XCTAssertEqual((error as? PSQLCastingError)?.file, #file) @@ -50,7 +50,7 @@ class String_PSQLCodableTests: XCTestCase { let dataTypes: [PSQLDataType] = [.text, .varchar, .name] for dataType in dataTypes { - let data = PSQLData(bytes: nil, dataType: dataType) + let data = PSQLData(bytes: nil, dataType: dataType, format: .binary) XCTAssertThrowsError(try data.decode(as: String.self, context: .forTests())) { error in XCTAssertEqual((error as? PSQLCastingError)?.line, #line - 1) XCTAssertEqual((error as? PSQLCastingError)?.file, #file) @@ -67,7 +67,7 @@ class String_PSQLCodableTests: XCTestCase { uuid.encode(into: &buffer, context: .forTests()) var decoded: String? - XCTAssertNoThrow(decoded = try String.decode(from: &buffer, type: .uuid, context: .forTests())) + XCTAssertNoThrow(decoded = try String.decode(from: &buffer, type: .uuid, format: .binary, context: .forTests())) XCTAssertEqual(decoded, uuid.uuidString) } @@ -78,7 +78,7 @@ class String_PSQLCodableTests: XCTestCase { // this makes only 15 bytes readable. this should lead to an error buffer.moveReaderIndex(forwardBy: 1) - XCTAssertThrowsError(try String.decode(from: &buffer, type: .uuid, context: .forTests())) { error in + XCTAssertThrowsError(try String.decode(from: &buffer, type: .uuid, format: .binary, context: .forTests())) { error in XCTAssertEqual((error as? PSQLCastingError)?.line, #line - 1) XCTAssertEqual((error as? PSQLCastingError)?.file, #file) diff --git a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift index 4d33efa5..7325955b 100644 --- a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift @@ -11,6 +11,7 @@ class UUID_PSQLCodableTests: XCTestCase { uuid.encode(into: &buffer, context: .forTests()) XCTAssertEqual(uuid.psqlType, .uuid) + XCTAssertEqual(uuid.psqlFormat, .binary) XCTAssertEqual(buffer.readableBytes, 16) var byteIterator = buffer.readableBytesView.makeIterator() @@ -32,13 +33,19 @@ class UUID_PSQLCodableTests: XCTestCase { XCTAssertEqual(byteIterator.next(), uuid.uuid.15) var decoded: UUID? - XCTAssertNoThrow(decoded = try UUID.decode(from: &buffer, type: .uuid, context: .forTests())) + XCTAssertNoThrow(decoded = try UUID.decode(from: &buffer, type: .uuid, format: .binary, context: .forTests())) XCTAssertEqual(decoded, uuid) } } func testDecodeFromString() { - let dataTypes: [PSQLDataType] = [.varchar, .text] + let options: [(PSQLFormat, PSQLDataType)] = [ + (.binary, .text), + (.binary, .varchar), + (.text, .uuid), + (.text, .text), + (.text, .varchar), + ] for _ in 0..<100 { // use uppercase @@ -46,10 +53,10 @@ class UUID_PSQLCodableTests: XCTestCase { var lowercaseBuffer = ByteBuffer() lowercaseBuffer.writeString(uuid.uuidString.lowercased()) - for dataType in dataTypes { + for (format, dataType) in options { var loopBuffer = lowercaseBuffer var decoded: UUID? - XCTAssertNoThrow(decoded = try UUID.decode(from: &loopBuffer, type: dataType, context: .forTests())) + XCTAssertNoThrow(decoded = try UUID.decode(from: &loopBuffer, type: dataType, format: format, context: .forTests())) XCTAssertEqual(decoded, uuid) } @@ -57,10 +64,10 @@ class UUID_PSQLCodableTests: XCTestCase { var uppercaseBuffer = ByteBuffer() uppercaseBuffer.writeString(uuid.uuidString) - for dataType in dataTypes { + for (format, dataType) in options { var loopBuffer = uppercaseBuffer var decoded: UUID? - XCTAssertNoThrow(decoded = try UUID.decode(from: &loopBuffer, type: dataType, context: .forTests())) + XCTAssertNoThrow(decoded = try UUID.decode(from: &loopBuffer, type: dataType, format: format, context: .forTests())) XCTAssertEqual(decoded, uuid) } } @@ -74,7 +81,7 @@ class UUID_PSQLCodableTests: XCTestCase { // this makes only 15 bytes readable. this should lead to an error buffer.moveReaderIndex(forwardBy: 1) - XCTAssertThrowsError(try UUID.decode(from: &buffer, type: .uuid, context: .forTests())) { error in + XCTAssertThrowsError(try UUID.decode(from: &buffer, type: .uuid, format: .binary, context: .forTests())) { error in XCTAssertEqual((error as? PSQLCastingError)?.line, #line - 1) XCTAssertEqual((error as? PSQLCastingError)?.file, #file) @@ -94,7 +101,7 @@ class UUID_PSQLCodableTests: XCTestCase { for dataType in dataTypes { var loopBuffer = buffer - XCTAssertThrowsError(try UUID.decode(from: &loopBuffer, type: dataType, context: .forTests())) { error in + XCTAssertThrowsError(try UUID.decode(from: &loopBuffer, type: dataType, format: .binary, context: .forTests())) { error in XCTAssertEqual((error as? PSQLCastingError)?.line, #line - 1) XCTAssertEqual((error as? PSQLCastingError)?.file, #file) @@ -112,7 +119,7 @@ class UUID_PSQLCodableTests: XCTestCase { let dataTypes: [PSQLDataType] = [.bool, .int8, .int2, .int4Array] for dataType in dataTypes { - let data = PSQLData(bytes: buffer, dataType: dataType) + let data = PSQLData(bytes: buffer, dataType: dataType, format: .binary) XCTAssertThrowsError(try data.decode(as: UUID.self, context: .forTests())) { error in XCTAssertEqual((error as? PSQLCastingError)?.line, #line - 1) @@ -124,4 +131,3 @@ class UUID_PSQLCodableTests: XCTestCase { } } } - diff --git a/Tests/PostgresNIOTests/New/Messages/BindTests.swift b/Tests/PostgresNIOTests/New/Messages/BindTests.swift index 55598fad..de6ace78 100644 --- a/Tests/PostgresNIOTests/New/Messages/BindTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BindTests.swift @@ -10,14 +10,15 @@ class BindTests: XCTestCase { let message = PSQLFrontendMessage.bind(bind) XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) - XCTAssertEqual(byteBuffer.readableBytes, 35) + XCTAssertEqual(byteBuffer.readableBytes, 37) XCTAssertEqual(PSQLFrontendMessage.ID.bind.byte, byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 34) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 36) XCTAssertEqual("", byteBuffer.readNullTerminatedString()) XCTAssertEqual("", byteBuffer.readNullTerminatedString()) - // all parameters have the same format: therefore one format byte is next + // the number of parameters + XCTAssertEqual(2, byteBuffer.readInteger(as: Int16.self)) + // all (two) parameters have the same format (binary) XCTAssertEqual(1, byteBuffer.readInteger(as: Int16.self)) - // all parameters have the same format (binary) XCTAssertEqual(1, byteBuffer.readInteger(as: Int16.self)) // read number of parameters diff --git a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift index 3ce7fb12..bbcb103d 100644 --- a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift @@ -7,8 +7,8 @@ class RowDescriptionTests: XCTestCase { func testDecode() { let columns: [PSQLBackendMessage.RowDescription.Column] = [ - .init(name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, formatCode: .binary), - .init(name: "Second", tableOID: 123, columnAttributeNumber: 456, dataType: .uuidArray, dataTypeSize: 567, dataTypeModifier: 123, formatCode: .text), + .init(name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary), + .init(name: "Second", tableOID: 123, columnAttributeNumber: 456, dataType: .uuidArray, dataTypeSize: 567, dataTypeModifier: 123, format: .text), ] let expected: [PSQLBackendMessage] = [ @@ -31,7 +31,7 @@ class RowDescriptionTests: XCTestCase { buffer.writeInteger(column.dataType.rawValue) buffer.writeInteger(column.dataTypeSize) buffer.writeInteger(column.dataTypeModifier) - buffer.writeInteger(column.formatCode.rawValue) + buffer.writeInteger(column.format.rawValue) } } } @@ -43,7 +43,7 @@ class RowDescriptionTests: XCTestCase { func testDecodeFailureBecauseOfMissingNullTerminationInColumnName() { let column = PSQLBackendMessage.RowDescription.Column( - name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, formatCode: .binary) + name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary) var buffer = ByteBuffer() buffer.writeBackendMessage(id: .rowDescription) { buffer in @@ -54,7 +54,7 @@ class RowDescriptionTests: XCTestCase { buffer.writeInteger(column.dataType.rawValue) buffer.writeInteger(column.dataTypeSize) buffer.writeInteger(column.dataTypeModifier) - buffer.writeInteger(column.formatCode.rawValue) + buffer.writeInteger(column.format.rawValue) } XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( @@ -66,7 +66,7 @@ class RowDescriptionTests: XCTestCase { func testDecodeFailureBecauseOfMissingColumnCount() { let column = PSQLBackendMessage.RowDescription.Column( - name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, formatCode: .binary) + name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary) var buffer = ByteBuffer() buffer.writeBackendMessage(id: .rowDescription) { buffer in @@ -76,7 +76,7 @@ class RowDescriptionTests: XCTestCase { buffer.writeInteger(column.dataType.rawValue) buffer.writeInteger(column.dataTypeSize) buffer.writeInteger(column.dataTypeModifier) - buffer.writeInteger(column.formatCode.rawValue) + buffer.writeInteger(column.format.rawValue) } XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( @@ -88,7 +88,7 @@ class RowDescriptionTests: XCTestCase { func testDecodeFailureBecauseInvalidFormatCode() { let column = PSQLBackendMessage.RowDescription.Column( - name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, formatCode: .binary) + name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary) var buffer = ByteBuffer() buffer.writeBackendMessage(id: .rowDescription) { buffer in @@ -111,7 +111,7 @@ class RowDescriptionTests: XCTestCase { func testDecodeFailureBecauseNegativeColumnCount() { let column = PSQLBackendMessage.RowDescription.Column( - name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, formatCode: .binary) + name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary) var buffer = ByteBuffer() buffer.writeBackendMessage(id: .rowDescription) { buffer in @@ -122,7 +122,7 @@ class RowDescriptionTests: XCTestCase { buffer.writeInteger(column.dataType.rawValue) buffer.writeInteger(column.dataTypeSize) buffer.writeInteger(column.dataTypeModifier) - buffer.writeInteger(column.formatCode.rawValue) + buffer.writeInteger(column.format.rawValue) } XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( diff --git a/Tests/PostgresNIOTests/New/PSQLDataTests.swift b/Tests/PostgresNIOTests/New/PSQLDataTests.swift index e6e2a8d2..55873310 100644 --- a/Tests/PostgresNIOTests/New/PSQLDataTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLDataTests.swift @@ -6,7 +6,7 @@ class PSQLDataTests: XCTestCase { func testStringDecoding() { let emptyBuffer: ByteBuffer? = nil - let data = PSQLData(bytes: emptyBuffer, dataType: .text) + let data = PSQLData(bytes: emptyBuffer, dataType: .text, format: .binary) var emptyResult: String? XCTAssertNoThrow(emptyResult = try data.decodeIfPresent(as: String.self, context: .forTests())) From 1d379346c731afc9ec76a3c9c58ae02dfa5c1369 Mon Sep 17 00:00:00 2001 From: Greg Ennis Date: Wed, 28 Jul 2021 12:24:11 -0400 Subject: [PATCH 011/292] Support AWS Redshift queries - SELECT Metadata without row number (#167) - Allow only one part in metadata response (Support AWS Redshift) Co-authored-by: Fabian Fett --- .../PostgresNIO/PostgresDatabase+Query.swift | 8 +++++++- .../New/PSQLMetadataTests.swift | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) create mode 100644 Tests/PostgresNIOTests/New/PSQLMetadataTests.swift diff --git a/Sources/PostgresNIO/PostgresDatabase+Query.swift b/Sources/PostgresNIO/PostgresDatabase+Query.swift index ce89bcb5..f9823ae6 100644 --- a/Sources/PostgresNIO/PostgresDatabase+Query.swift +++ b/Sources/PostgresNIO/PostgresDatabase+Query.swift @@ -74,7 +74,13 @@ public struct PostgresQueryMetadata { self.command = .init(parts[0]) self.oid = Int(parts[1]) self.rows = Int(parts[2]) - case "DELETE", "UPDATE", "SELECT", "MOVE", "FETCH", "COPY": + case "SELECT" where parts.count == 1: + // AWS Redshift does not return the actual row count as defined in the postgres wire spec for SELECT: + // https://www.postgresql.org/docs/13/protocol-message-formats.html in section `CommandComplete` + self.command = "SELECT" + self.oid = nil + self.rows = nil + case "SELECT", "DELETE", "UPDATE", "MOVE", "FETCH", "COPY": // rows guard parts.count == 2 else { return nil diff --git a/Tests/PostgresNIOTests/New/PSQLMetadataTests.swift b/Tests/PostgresNIOTests/New/PSQLMetadataTests.swift new file mode 100644 index 00000000..e190c740 --- /dev/null +++ b/Tests/PostgresNIOTests/New/PSQLMetadataTests.swift @@ -0,0 +1,18 @@ +import NIO +import XCTest +@testable import PostgresNIO + +class PSQLMetadataTests: XCTestCase { + func testSelect() { + XCTAssertEqual(100, PostgresQueryMetadata(string: "SELECT 100")?.rows) + XCTAssertNotNil(PostgresQueryMetadata(string: "SELECT")) + XCTAssertNil(PostgresQueryMetadata(string: "SELECT")?.rows) + XCTAssertNil(PostgresQueryMetadata(string: "SELECT 100 100")) + } + + func testUpdate() { + XCTAssertEqual(100, PostgresQueryMetadata(string: "UPDATE 100")?.rows) + XCTAssertNil(PostgresQueryMetadata(string: "UPDATE")) + XCTAssertNil(PostgresQueryMetadata(string: "UPDATE 100 100")) + } +} From 6f6ec7b40aa6881697333d79730cfc99bea79acc Mon Sep 17 00:00:00 2001 From: David Nadoba Date: Wed, 18 Aug 2021 10:07:21 +0200 Subject: [PATCH 012/292] Splits CI Jobs Into Unit And Integration Tests On Linux (#169) * split testing job into unit and integration test jobs * execute workflow on push for now. Will be remove before opening the PR * use correct name for nightly swift versions * use correct format for specifying docker image * remove amazonlinux2 * only run on pull_request --- .github/workflows/test.yml | 40 ++++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2bf7e19d..28bd3784 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,6 +13,7 @@ jobs: swiftver: - 5.2 - 5.3 + - 5.4 dbimage: - postgres:13 - postgres:12 @@ -55,9 +56,35 @@ jobs: POSTGRES_HOSTNAME: psql-a POSTGRES_HOSTNAME_A: psql-a POSTGRES_HOSTNAME_B: psql-b + + # Run unit tests on Linux Swift runners on + linux-unit-tests: + strategy: + fail-fast: false + matrix: + swiftver: + - swift:5.2 + - swift:5.3 + - swift:5.4 + - swiftlang/swift:nightly-5.5 + - swiftlang/swift:nightly-main + swiftos: + #- xenial + #- bionic + - focal + #- centos7 + #- centos8 + #- amazonlinux2 + container: ${{ format('{0}-{1}', matrix.swiftver, matrix.swiftos) }} + runs-on: ubuntu-latest + steps: + - name: Check out code + uses: actions/checkout@v2 + - name: Run tests with Thread Sanitizer + run: swift test --enable-test-discovery --sanitize=thread --filter=^PostgresNIOTests - # Run package tests on Linux Swift runners against supported PSQL versions - linux: + # Run integration tests on Linux Swift runners against supported PSQL versions + linux-integration-tests: strategy: fail-fast: false matrix: @@ -70,17 +97,14 @@ jobs: - md5 - scram-sha-256 swiftver: - - swift:5.2 - - swift:5.3 - #- swiftlang/swift:nightly-5.3 - #- swiftlang/swift:nightly-master + - swift:5.4 swiftos: #- xenial #- bionic - focal #- centos7 #- centos8 - - amazonlinux2 + #- amazonlinux2 container: ${{ format('{0}-{1}', matrix.swiftver, matrix.swiftos) }} runs-on: ubuntu-latest services: @@ -96,7 +120,7 @@ jobs: - name: Check out code uses: actions/checkout@v2 - name: Run tests with Thread Sanitizer - run: swift test --enable-test-discovery --sanitize=thread + run: swift test --enable-test-discovery --sanitize=thread --filter=^IntegrationTests env: POSTGRES_HOSTNAME: psql POSTGRES_USER: vapor_username From 08c0dc590f4e149857d99dc91be4da342444dece Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 18 Aug 2021 10:32:07 +0200 Subject: [PATCH 013/292] Fix NIO imports (#171) ### Motivation - SwiftNIO 2.32.0 introduces explicit modules for its Core and Posix. - To write great platform independent code in most parts we should only ever import `NIOCore` ### Changes - This pr explicitly imports NIOCore, NIOPosix and NIOEmbedded where needed. --- Package.swift | 7 +++++-- .../Connection/PostgresConnection+Authenticate.swift | 3 ++- .../Connection/PostgresConnection+Connect.swift | 4 +++- .../Connection/PostgresConnection+Database.swift | 1 + .../Connection/PostgresConnection+Notifications.swift | 2 +- Sources/PostgresNIO/Connection/PostgresConnection.swift | 2 +- .../Connection/PostgresDatabase+PreparedQuery.swift | 3 ++- Sources/PostgresNIO/Data/PostgresData+Array.swift | 2 ++ Sources/PostgresNIO/Data/PostgresData+Bool.swift | 2 ++ Sources/PostgresNIO/Data/PostgresData+Bytes.swift | 1 + Sources/PostgresNIO/Data/PostgresData+Date.swift | 3 ++- Sources/PostgresNIO/Data/PostgresData+Double.swift | 2 ++ Sources/PostgresNIO/Data/PostgresData+JSON.swift | 3 ++- Sources/PostgresNIO/Data/PostgresData+JSONB.swift | 3 ++- Sources/PostgresNIO/Data/PostgresData+Numeric.swift | 1 + Sources/PostgresNIO/Data/PostgresData+String.swift | 2 ++ Sources/PostgresNIO/Data/PostgresData+UUID.swift | 1 + Sources/PostgresNIO/Data/PostgresData.swift | 2 +- Sources/PostgresNIO/Message/PostgresMessage+0.swift | 2 ++ .../Message/PostgresMessage+Authentication.swift | 2 +- .../Message/PostgresMessage+BackendKeyData.swift | 2 +- Sources/PostgresNIO/Message/PostgresMessage+Bind.swift | 2 +- Sources/PostgresNIO/Message/PostgresMessage+Close.swift | 2 +- .../Message/PostgresMessage+CommandComplete.swift | 2 +- Sources/PostgresNIO/Message/PostgresMessage+DataRow.swift | 2 +- Sources/PostgresNIO/Message/PostgresMessage+Describe.swift | 2 +- Sources/PostgresNIO/Message/PostgresMessage+Error.swift | 2 +- Sources/PostgresNIO/Message/PostgresMessage+Execute.swift | 2 +- .../PostgresNIO/Message/PostgresMessage+Identifier.swift | 2 +- .../Message/PostgresMessage+NotificationResponse.swift | 2 +- .../Message/PostgresMessage+ParameterDescription.swift | 2 +- .../Message/PostgresMessage+ParameterStatus.swift | 2 +- Sources/PostgresNIO/Message/PostgresMessage+Parse.swift | 2 +- Sources/PostgresNIO/Message/PostgresMessage+Password.swift | 2 +- .../Message/PostgresMessage+ReadyForQuery.swift | 2 +- .../Message/PostgresMessage+RowDescription.swift | 2 +- .../PostgresNIO/Message/PostgresMessage+SASLResponse.swift | 2 +- .../PostgresNIO/Message/PostgresMessage+SSLRequest.swift | 2 +- .../PostgresNIO/Message/PostgresMessage+SimpleQuery.swift | 2 +- Sources/PostgresNIO/Message/PostgresMessage+Startup.swift | 2 +- Sources/PostgresNIO/Message/PostgresMessage+Sync.swift | 2 +- .../PostgresNIO/Message/PostgresMessage+Terminate.swift | 2 ++ Sources/PostgresNIO/Message/PostgresMessageDecoder.swift | 3 ++- Sources/PostgresNIO/Message/PostgresMessageEncoder.swift | 3 ++- Sources/PostgresNIO/Message/PostgresMessageType.swift | 2 ++ .../AuthenticationStateMachine.swift | 2 +- .../Connection State Machine/ConnectionStateMachine.swift | 1 + .../ExtendedQueryStateMachine.swift | 1 + Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift | 2 +- Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift | 2 ++ Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift | 1 + Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift | 1 + Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift | 2 ++ Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift | 2 ++ Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift | 1 + Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift | 2 ++ .../New/Data/RawRepresentable+PSQLCodable.swift | 2 ++ Sources/PostgresNIO/New/Data/String+PSQLCodable.swift | 1 + Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift | 1 + Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift | 2 +- Sources/PostgresNIO/New/Messages/Authentication.swift | 2 +- Sources/PostgresNIO/New/Messages/BackendKeyData.swift | 2 ++ Sources/PostgresNIO/New/Messages/Bind.swift | 2 ++ Sources/PostgresNIO/New/Messages/Cancel.swift | 2 ++ Sources/PostgresNIO/New/Messages/Close.swift | 2 ++ Sources/PostgresNIO/New/Messages/DataRow.swift | 2 +- Sources/PostgresNIO/New/Messages/Describe.swift | 2 ++ Sources/PostgresNIO/New/Messages/ErrorResponse.swift | 2 ++ Sources/PostgresNIO/New/Messages/Execute.swift | 2 ++ .../PostgresNIO/New/Messages/NotificationResponse.swift | 2 +- .../PostgresNIO/New/Messages/ParameterDescription.swift | 2 ++ Sources/PostgresNIO/New/Messages/ParameterStatus.swift | 2 ++ Sources/PostgresNIO/New/Messages/Parse.swift | 2 ++ Sources/PostgresNIO/New/Messages/Password.swift | 2 ++ Sources/PostgresNIO/New/Messages/ReadyForQuery.swift | 2 ++ Sources/PostgresNIO/New/Messages/RowDescription.swift | 2 ++ Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift | 2 ++ Sources/PostgresNIO/New/Messages/SASLResponse.swift | 2 ++ Sources/PostgresNIO/New/Messages/SSLRequest.swift | 2 +- Sources/PostgresNIO/New/Messages/Startup.swift | 2 +- Sources/PostgresNIO/New/PSQL+JSON.swift | 3 ++- Sources/PostgresNIO/New/PSQLBackendMessage.swift | 3 ++- Sources/PostgresNIO/New/PSQLChannelHandler.swift | 2 +- Sources/PostgresNIO/New/PSQLCodable.swift | 2 ++ Sources/PostgresNIO/New/PSQLConnection.swift | 3 ++- Sources/PostgresNIO/New/PSQLData.swift | 2 ++ Sources/PostgresNIO/New/PSQLError.swift | 2 +- Sources/PostgresNIO/New/PSQLEventsHandler.swift | 2 ++ Sources/PostgresNIO/New/PSQLFrontendMessage.swift | 2 +- Sources/PostgresNIO/New/PSQLRows.swift | 2 +- Sources/PostgresNIO/New/PSQLTask.swift | 3 +++ Sources/PostgresNIO/Postgres+PSQLCompat.swift | 2 ++ Sources/PostgresNIO/PostgresDatabase+Query.swift | 2 +- Sources/PostgresNIO/PostgresDatabase+SimpleQuery.swift | 2 +- Sources/PostgresNIO/PostgresDatabase.swift | 3 +++ Sources/PostgresNIO/Utilities/Exports.swift | 1 + Sources/PostgresNIO/Utilities/NIOUtils.swift | 2 +- Tests/IntegrationTests/PSQLIntegrationTests.swift | 4 +++- Tests/IntegrationTests/PerformanceTests.swift | 4 +++- Tests/IntegrationTests/PostgresNIOTests.swift | 2 ++ Tests/IntegrationTests/Utilities.swift | 3 ++- .../AuthenticationStateMachineTests.swift | 1 + .../ConnectionStateMachineTests.swift | 4 +++- .../ExtendedQueryStateMachineTests.swift | 3 +++ .../PrepareStatementStateMachineTests.swift | 1 + .../PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift | 1 + .../PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift | 1 + .../PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift | 1 + .../PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift | 1 + .../PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift | 1 + .../PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift | 1 + .../New/Data/Optional+PSQLCodableTests.swift | 1 + .../New/Data/RawRepresentable+PSQLCodableTests.swift | 1 + .../New/Data/String+PSQLCodableTests.swift | 1 + .../PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift | 1 + .../PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift | 2 +- .../New/Extensions/ConnectionAction+TestUtils.swift | 1 + Tests/PostgresNIOTests/New/Extensions/LoggingUtils.swift | 2 -- .../New/Extensions/PSQLFrontendMessage+Equatable.swift | 3 ++- .../New/Messages/AuthenticationTests.swift | 4 ++-- .../New/Messages/BackendKeyDataTests.swift | 4 ++-- Tests/PostgresNIOTests/New/Messages/BindTests.swift | 1 + Tests/PostgresNIOTests/New/Messages/CancelTests.swift | 1 + Tests/PostgresNIOTests/New/Messages/CloseTests.swift | 1 + Tests/PostgresNIOTests/New/Messages/DataRowTests.swift | 4 ++-- Tests/PostgresNIOTests/New/Messages/DescribeTests.swift | 1 + .../PostgresNIOTests/New/Messages/ErrorResponseTests.swift | 4 ++-- Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift | 1 + .../New/Messages/NotificationResponseTests.swift | 4 ++-- .../New/Messages/ParameterDescriptionTests.swift | 4 ++-- .../New/Messages/ParameterStatusTests.swift | 4 ++-- Tests/PostgresNIOTests/New/Messages/ParseTests.swift | 3 +-- Tests/PostgresNIOTests/New/Messages/PasswordTests.swift | 1 + .../PostgresNIOTests/New/Messages/ReadyForQueryTests.swift | 4 ++-- .../New/Messages/RowDescriptionTests.swift | 4 ++-- .../New/Messages/SASLInitialResponseTests.swift | 3 +-- .../PostgresNIOTests/New/Messages/SASLResponseTests.swift | 3 +-- Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift | 1 + Tests/PostgresNIOTests/New/Messages/StartupTests.swift | 1 + Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift | 3 ++- Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift | 4 +++- Tests/PostgresNIOTests/New/PSQLConnectionTests.swift | 3 ++- Tests/PostgresNIOTests/New/PSQLDataTests.swift | 2 +- Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift | 1 + Tests/PostgresNIOTests/New/PSQLMetadataTests.swift | 2 +- 145 files changed, 219 insertions(+), 90 deletions(-) delete mode 100644 Tests/PostgresNIOTests/New/Extensions/LoggingUtils.swift diff --git a/Package.swift b/Package.swift index 86dfd2d1..1c393213 100644 --- a/Package.swift +++ b/Package.swift @@ -13,7 +13,7 @@ let package = Package( .library(name: "PostgresNIO", targets: ["PostgresNIO"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-nio.git", from: "2.28.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.32.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.12.0"), .package(url: "https://github.com/apple/swift-crypto.git", from: "1.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.0.0"), @@ -25,12 +25,15 @@ let package = Package( .product(name: "Logging", package: "swift-log"), .product(name: "Metrics", package: "swift-metrics"), .product(name: "NIO", package: "swift-nio"), + .product(name: "NIOCore", package: "swift-nio"), + .product(name: "NIOPosix", package: "swift-nio"), .product(name: "NIOTLS", package: "swift-nio"), - .product(name: "NIOFoundationCompat", package: "swift-nio"), .product(name: "NIOSSL", package: "swift-nio-ssl"), + .product(name: "NIOFoundationCompat", package: "swift-nio"), ]), .testTarget(name: "PostgresNIOTests", dependencies: [ .target(name: "PostgresNIO"), + .product(name: "NIOEmbedded", package: "swift-nio"), .product(name: "NIOTestUtils", package: "swift-nio"), ]), .testTarget(name: "IntegrationTests", dependencies: [ diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Authenticate.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Authenticate.swift index c0fc299c..d58943ba 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Authenticate.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Authenticate.swift @@ -1,4 +1,5 @@ -import NIO +import NIOCore +import Logging extension PostgresConnection { public func authenticate( diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift index 3a1cf425..49463aa5 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift @@ -1,4 +1,6 @@ -import NIO +import NIOCore +import NIOSSL +import Logging extension PostgresConnection { public static func connect( diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift index f6f99a1c..6ee6ddbf 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift @@ -1,3 +1,4 @@ +import NIOCore import Logging import struct Foundation.Data diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Notifications.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Notifications.swift index 3ba591cf..9a21437d 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Notifications.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Notifications.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore import Logging /// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support. diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 2e1c8da0..c400711e 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore import Logging import struct Foundation.UUID diff --git a/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift b/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift index 77f8be45..cf315b19 100644 --- a/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift +++ b/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift @@ -1,4 +1,5 @@ -import Foundation +import NIOCore +import struct Foundation.UUID extension PostgresDatabase { public func prepare(query: String) -> EventLoopFuture { diff --git a/Sources/PostgresNIO/Data/PostgresData+Array.swift b/Sources/PostgresNIO/Data/PostgresData+Array.swift index 4febed36..bbb420bc 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Array.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Array.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PostgresData { public init(array: [T]) where T: PostgresDataConvertible diff --git a/Sources/PostgresNIO/Data/PostgresData+Bool.swift b/Sources/PostgresNIO/Data/PostgresData+Bool.swift index 79c31dd8..99e0c670 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Bool.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Bool.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PostgresData { public init(bool: Bool) { var buffer = ByteBufferAllocator().buffer(capacity: 1) diff --git a/Sources/PostgresNIO/Data/PostgresData+Bytes.swift b/Sources/PostgresNIO/Data/PostgresData+Bytes.swift index 8316f61e..292c3c0a 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Bytes.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Bytes.swift @@ -1,4 +1,5 @@ import struct Foundation.Data +import NIOCore extension PostgresData { public init(bytes: Bytes) diff --git a/Sources/PostgresNIO/Data/PostgresData+Date.swift b/Sources/PostgresNIO/Data/PostgresData+Date.swift index 0afbc78f..86fa2f17 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Date.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Date.swift @@ -1,4 +1,5 @@ -import Foundation +import struct Foundation.Date +import NIOCore extension PostgresData { public init(date: Date) { diff --git a/Sources/PostgresNIO/Data/PostgresData+Double.swift b/Sources/PostgresNIO/Data/PostgresData+Double.swift index 6012cc03..7435cdaa 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Double.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Double.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PostgresData { public init(double: Double) { var buffer = ByteBufferAllocator().buffer(capacity: 0) diff --git a/Sources/PostgresNIO/Data/PostgresData+JSON.swift b/Sources/PostgresNIO/Data/PostgresData+JSON.swift index 2d439fa9..519b721d 100644 --- a/Sources/PostgresNIO/Data/PostgresData+JSON.swift +++ b/Sources/PostgresNIO/Data/PostgresData+JSON.swift @@ -1,4 +1,5 @@ -import Foundation +import struct Foundation.Data +import NIOCore extension PostgresData { public init(json jsonData: Data) { diff --git a/Sources/PostgresNIO/Data/PostgresData+JSONB.swift b/Sources/PostgresNIO/Data/PostgresData+JSONB.swift index 2d0ca078..0b374ba6 100644 --- a/Sources/PostgresNIO/Data/PostgresData+JSONB.swift +++ b/Sources/PostgresNIO/Data/PostgresData+JSONB.swift @@ -1,4 +1,5 @@ -import Foundation +import NIOCore +import struct Foundation.Data fileprivate let jsonBVersionBytes: [UInt8] = [0x01] diff --git a/Sources/PostgresNIO/Data/PostgresData+Numeric.swift b/Sources/PostgresNIO/Data/PostgresData+Numeric.swift index 96bd6a77..5e564d6d 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Numeric.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Numeric.swift @@ -1,3 +1,4 @@ +import NIOCore import struct Foundation.Decimal public struct PostgresNumeric: CustomStringConvertible, CustomDebugStringConvertible, ExpressibleByStringLiteral { diff --git a/Sources/PostgresNIO/Data/PostgresData+String.swift b/Sources/PostgresNIO/Data/PostgresData+String.swift index feb27641..79d9d428 100644 --- a/Sources/PostgresNIO/Data/PostgresData+String.swift +++ b/Sources/PostgresNIO/Data/PostgresData+String.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PostgresData { public init(string: String) { var buffer = ByteBufferAllocator().buffer(capacity: string.utf8.count) diff --git a/Sources/PostgresNIO/Data/PostgresData+UUID.swift b/Sources/PostgresNIO/Data/PostgresData+UUID.swift index 2bfc963e..148a9e66 100644 --- a/Sources/PostgresNIO/Data/PostgresData+UUID.swift +++ b/Sources/PostgresNIO/Data/PostgresData+UUID.swift @@ -1,4 +1,5 @@ import Foundation +import NIOCore extension PostgresData { public init(uuid: UUID) { diff --git a/Sources/PostgresNIO/Data/PostgresData.swift b/Sources/PostgresNIO/Data/PostgresData.swift index 86686556..916c27bd 100644 --- a/Sources/PostgresNIO/Data/PostgresData.swift +++ b/Sources/PostgresNIO/Data/PostgresData.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore import Foundation public struct PostgresData: CustomStringConvertible, CustomDebugStringConvertible { diff --git a/Sources/PostgresNIO/Message/PostgresMessage+0.swift b/Sources/PostgresNIO/Message/PostgresMessage+0.swift index d7d600a8..f33e89a3 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+0.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+0.swift @@ -1,3 +1,5 @@ +import NIOCore + /// A frontend or backend Postgres message. public struct PostgresMessage: Equatable { public var identifier: Identifier diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Authentication.swift b/Sources/PostgresNIO/Message/PostgresMessage+Authentication.swift index c515cd9c..44523a5c 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Authentication.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Authentication.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// Authentication request returned by the server. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+BackendKeyData.swift b/Sources/PostgresNIO/Message/PostgresMessage+BackendKeyData.swift index f25994d8..85c2277a 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+BackendKeyData.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+BackendKeyData.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as cancellation key data. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift b/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift index 89ef11c8..a5687c40 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a Bind command. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Close.swift b/Sources/PostgresNIO/Message/PostgresMessage+Close.swift index 82389bf2..9e5dd99e 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Close.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Close.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a Close Command diff --git a/Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift b/Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift index 8ac6e706..406dc036 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a Close command. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+DataRow.swift b/Sources/PostgresNIO/Message/PostgresMessage+DataRow.swift index 2ead1b3a..e5cc3d9d 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+DataRow.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+DataRow.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a data row. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Describe.swift b/Sources/PostgresNIO/Message/PostgresMessage+Describe.swift index 3a5ebd46..8c3bc8f5 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Describe.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Describe.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a Describe command. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Error.swift b/Sources/PostgresNIO/Message/PostgresMessage+Error.swift index 9b0a18cd..51b9be7e 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Error.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Error.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// First message sent from the frontend during startup. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Execute.swift b/Sources/PostgresNIO/Message/PostgresMessage+Execute.swift index 7e4b54a4..4b8bc999 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Execute.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Execute.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as an Execute command. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift b/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift index 2f4d599f..3c0c3ef0 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies an incoming or outgoing postgres message. Sent as the first byte, before the message size. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift b/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift index b381bfaf..4979e354 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a notification response. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+ParameterDescription.swift b/Sources/PostgresNIO/Message/PostgresMessage+ParameterDescription.swift index a3806d0f..3dfdb8e1 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+ParameterDescription.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+ParameterDescription.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a parameter description. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+ParameterStatus.swift b/Sources/PostgresNIO/Message/PostgresMessage+ParameterStatus.swift index 09939bef..5e2f5881 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+ParameterStatus.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+ParameterStatus.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { public struct ParameterStatus: PostgresMessageType, CustomStringConvertible { diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Parse.swift b/Sources/PostgresNIO/Message/PostgresMessage+Parse.swift index 749a6949..030076d0 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Parse.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Parse.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a Parse command. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Password.swift b/Sources/PostgresNIO/Message/PostgresMessage+Password.swift index f28463e7..5b2cef63 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Password.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Password.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a password response. Note that this is also used for diff --git a/Sources/PostgresNIO/Message/PostgresMessage+ReadyForQuery.swift b/Sources/PostgresNIO/Message/PostgresMessage+ReadyForQuery.swift index b05e833b..c46047dd 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+ReadyForQuery.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+ReadyForQuery.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message type. ReadyForQuery is sent whenever the backend is ready for a new query cycle. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift b/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift index 61aec62b..48a90c18 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a row description. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+SASLResponse.swift b/Sources/PostgresNIO/Message/PostgresMessage+SASLResponse.swift index 724188b0..553edc2c 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+SASLResponse.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+SASLResponse.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// SASL ongoing challenge response message sent by the client. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+SSLRequest.swift b/Sources/PostgresNIO/Message/PostgresMessage+SSLRequest.swift index 9133d26a..a636f23f 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+SSLRequest.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+SSLRequest.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// A message asking the PostgreSQL server if SSL is supported diff --git a/Sources/PostgresNIO/Message/PostgresMessage+SimpleQuery.swift b/Sources/PostgresNIO/Message/PostgresMessage+SimpleQuery.swift index 80e106b5..7b1ec2f9 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+SimpleQuery.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+SimpleQuery.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a simple query. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Startup.swift b/Sources/PostgresNIO/Message/PostgresMessage+Startup.swift index 25f68772..d4d09009 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Startup.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Startup.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// First message sent from the frontend during startup. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Sync.swift b/Sources/PostgresNIO/Message/PostgresMessage+Sync.swift index 6e47cefb..37d54dd7 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Sync.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Sync.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PostgresMessage { /// Identifies the message as a Bind command. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Terminate.swift b/Sources/PostgresNIO/Message/PostgresMessage+Terminate.swift index 61227fdf..5e34665a 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Terminate.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Terminate.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PostgresMessage { public struct Terminate: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { diff --git a/Sources/PostgresNIO/Message/PostgresMessageDecoder.swift b/Sources/PostgresNIO/Message/PostgresMessageDecoder.swift index 9a64e827..53ce73de 100644 --- a/Sources/PostgresNIO/Message/PostgresMessageDecoder.swift +++ b/Sources/PostgresNIO/Message/PostgresMessageDecoder.swift @@ -1,4 +1,5 @@ -import NIO +import NIOCore +import Logging public final class PostgresMessageDecoder: ByteToMessageDecoder { /// See `ByteToMessageDecoder`. diff --git a/Sources/PostgresNIO/Message/PostgresMessageEncoder.swift b/Sources/PostgresNIO/Message/PostgresMessageEncoder.swift index 4eca9bc5..19f467a4 100644 --- a/Sources/PostgresNIO/Message/PostgresMessageEncoder.swift +++ b/Sources/PostgresNIO/Message/PostgresMessageEncoder.swift @@ -1,4 +1,5 @@ -import NIO +import NIOCore +import Logging public final class PostgresMessageEncoder: MessageToByteEncoder { /// See `MessageToByteEncoder`. diff --git a/Sources/PostgresNIO/Message/PostgresMessageType.swift b/Sources/PostgresNIO/Message/PostgresMessageType.swift index 9a69fa30..604da4b9 100644 --- a/Sources/PostgresNIO/Message/PostgresMessageType.swift +++ b/Sources/PostgresNIO/Message/PostgresMessageType.swift @@ -1,3 +1,5 @@ +import NIOCore + public protocol PostgresMessageType { static var identifier: PostgresMessage.Identifier { get } static func parse(from buffer: inout ByteBuffer) throws -> Self diff --git a/Sources/PostgresNIO/New/Connection State Machine/AuthenticationStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/AuthenticationStateMachine.swift index ffcf3330..5848288d 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/AuthenticationStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/AuthenticationStateMachine.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore struct AuthenticationStateMachine { diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 1c3629b7..49168e97 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -1,3 +1,4 @@ +import NIOCore struct ConnectionStateMachine { diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index faf15626..f1ae086f 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -1,3 +1,4 @@ +import NIOCore struct ExtendedQueryStateMachine { diff --git a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift index 607a0d5b..d2211885 100644 --- a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore import struct Foundation.UUID /// A type, of which arrays can be encoded into and decoded from a postgres binary format diff --git a/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift index f67031c0..9ab2cc0f 100644 --- a/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift @@ -1,3 +1,5 @@ +import NIOCore + extension Bool: PSQLCodable { var psqlType: PSQLDataType { .bool diff --git a/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift index 3745e704..be8b2dd8 100644 --- a/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift @@ -1,4 +1,5 @@ import struct Foundation.Data +import NIOCore import NIOFoundationCompat extension PSQLEncodable where Self: Sequence, Self.Element == UInt8 { diff --git a/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift index a0e9efff..f78a915b 100644 --- a/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift @@ -1,3 +1,4 @@ +import NIOCore import struct Foundation.Date extension Date: PSQLCodable { diff --git a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift index be9bc045..e86894a2 100644 --- a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift @@ -1,3 +1,5 @@ +import NIOCore + extension Float: PSQLCodable { var psqlType: PSQLDataType { .float4 diff --git a/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift index 11c2c46c..2c421e92 100644 --- a/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift @@ -1,3 +1,5 @@ +import NIOCore + extension UInt8: PSQLCodable { var psqlType: PSQLDataType { .char diff --git a/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift index 8ca5f08c..0a321003 100644 --- a/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift @@ -1,3 +1,4 @@ +import NIOCore import NIOFoundationCompat import class Foundation.JSONEncoder import class Foundation.JSONDecoder diff --git a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift index 0005f7f8..99332221 100644 --- a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift @@ -1,3 +1,5 @@ +import NIOCore + extension Optional: PSQLDecodable where Wrapped: PSQLDecodable { static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Optional { preconditionFailure("This code path should never be hit.") diff --git a/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift index f2096e77..02bafa39 100644 --- a/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PSQLCodable where Self: RawRepresentable, RawValue: PSQLCodable { var psqlType: PSQLDataType { self.rawValue.psqlType diff --git a/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift index 9e325435..cff48330 100644 --- a/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift @@ -1,3 +1,4 @@ +import NIOCore import struct Foundation.UUID extension String: PSQLCodable { diff --git a/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift index fcabd094..5e259c4b 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift @@ -1,3 +1,4 @@ +import NIOCore import struct Foundation.UUID import typealias Foundation.uuid_t diff --git a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift index 10dd334a..3245b168 100644 --- a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift +++ b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore internal extension ByteBuffer { mutating func writeNullTerminatedString(_ string: String) { diff --git a/Sources/PostgresNIO/New/Messages/Authentication.swift b/Sources/PostgresNIO/New/Messages/Authentication.swift index 5586c775..ef1ec2d3 100644 --- a/Sources/PostgresNIO/New/Messages/Authentication.swift +++ b/Sources/PostgresNIO/New/Messages/Authentication.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PSQLBackendMessage { diff --git a/Sources/PostgresNIO/New/Messages/BackendKeyData.swift b/Sources/PostgresNIO/New/Messages/BackendKeyData.swift index d4237498..c9db7907 100644 --- a/Sources/PostgresNIO/New/Messages/BackendKeyData.swift +++ b/Sources/PostgresNIO/New/Messages/BackendKeyData.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PSQLBackendMessage { struct BackendKeyData: PayloadDecodable, Equatable { diff --git a/Sources/PostgresNIO/New/Messages/Bind.swift b/Sources/PostgresNIO/New/Messages/Bind.swift index f69b8d4a..110d7866 100644 --- a/Sources/PostgresNIO/New/Messages/Bind.swift +++ b/Sources/PostgresNIO/New/Messages/Bind.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PSQLFrontendMessage { struct Bind { diff --git a/Sources/PostgresNIO/New/Messages/Cancel.swift b/Sources/PostgresNIO/New/Messages/Cancel.swift index 11f08855..7983e0b3 100644 --- a/Sources/PostgresNIO/New/Messages/Cancel.swift +++ b/Sources/PostgresNIO/New/Messages/Cancel.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PSQLFrontendMessage { struct Cancel: PayloadEncodable, Equatable { diff --git a/Sources/PostgresNIO/New/Messages/Close.swift b/Sources/PostgresNIO/New/Messages/Close.swift index 47396e9b..fa755dc3 100644 --- a/Sources/PostgresNIO/New/Messages/Close.swift +++ b/Sources/PostgresNIO/New/Messages/Close.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PSQLFrontendMessage { enum Close: PayloadEncodable, Equatable { diff --git a/Sources/PostgresNIO/New/Messages/DataRow.swift b/Sources/PostgresNIO/New/Messages/DataRow.swift index bdb823ba..3d1b982b 100644 --- a/Sources/PostgresNIO/New/Messages/DataRow.swift +++ b/Sources/PostgresNIO/New/Messages/DataRow.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PSQLBackendMessage { diff --git a/Sources/PostgresNIO/New/Messages/Describe.swift b/Sources/PostgresNIO/New/Messages/Describe.swift index 74845050..76ba56e1 100644 --- a/Sources/PostgresNIO/New/Messages/Describe.swift +++ b/Sources/PostgresNIO/New/Messages/Describe.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PSQLFrontendMessage { enum Describe: PayloadEncodable, Equatable { diff --git a/Sources/PostgresNIO/New/Messages/ErrorResponse.swift b/Sources/PostgresNIO/New/Messages/ErrorResponse.swift index cfe943f4..4dc5bc99 100644 --- a/Sources/PostgresNIO/New/Messages/ErrorResponse.swift +++ b/Sources/PostgresNIO/New/Messages/ErrorResponse.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PSQLBackendMessage { enum Field: UInt8, Hashable { diff --git a/Sources/PostgresNIO/New/Messages/Execute.swift b/Sources/PostgresNIO/New/Messages/Execute.swift index 998bf952..f88e4482 100644 --- a/Sources/PostgresNIO/New/Messages/Execute.swift +++ b/Sources/PostgresNIO/New/Messages/Execute.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PSQLFrontendMessage { struct Execute: PayloadEncodable, Equatable { diff --git a/Sources/PostgresNIO/New/Messages/NotificationResponse.swift b/Sources/PostgresNIO/New/Messages/NotificationResponse.swift index 36ad90f4..72ff8141 100644 --- a/Sources/PostgresNIO/New/Messages/NotificationResponse.swift +++ b/Sources/PostgresNIO/New/Messages/NotificationResponse.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PSQLBackendMessage { diff --git a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift index 5c49440c..340492c7 100644 --- a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift +++ b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PSQLBackendMessage { struct ParameterDescription: PayloadDecodable, Equatable { diff --git a/Sources/PostgresNIO/New/Messages/ParameterStatus.swift b/Sources/PostgresNIO/New/Messages/ParameterStatus.swift index 891ea89a..0a40962b 100644 --- a/Sources/PostgresNIO/New/Messages/ParameterStatus.swift +++ b/Sources/PostgresNIO/New/Messages/ParameterStatus.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PSQLBackendMessage { struct ParameterStatus: PayloadDecodable, Equatable { diff --git a/Sources/PostgresNIO/New/Messages/Parse.swift b/Sources/PostgresNIO/New/Messages/Parse.swift index f72735de..1f45115e 100644 --- a/Sources/PostgresNIO/New/Messages/Parse.swift +++ b/Sources/PostgresNIO/New/Messages/Parse.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PSQLFrontendMessage { struct Parse: PayloadEncodable, Equatable { diff --git a/Sources/PostgresNIO/New/Messages/Password.swift b/Sources/PostgresNIO/New/Messages/Password.swift index cbb464cb..e8942561 100644 --- a/Sources/PostgresNIO/New/Messages/Password.swift +++ b/Sources/PostgresNIO/New/Messages/Password.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PSQLFrontendMessage { struct Password: PayloadEncodable, Equatable { diff --git a/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift b/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift index 61bc76b1..35529db1 100644 --- a/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift +++ b/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PSQLBackendMessage { enum TransactionState: PayloadDecodable, RawRepresentable { typealias RawValue = UInt8 diff --git a/Sources/PostgresNIO/New/Messages/RowDescription.swift b/Sources/PostgresNIO/New/Messages/RowDescription.swift index fdb495a5..20845958 100644 --- a/Sources/PostgresNIO/New/Messages/RowDescription.swift +++ b/Sources/PostgresNIO/New/Messages/RowDescription.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PSQLBackendMessage { struct RowDescription: PayloadDecodable, Equatable { diff --git a/Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift b/Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift index 5762f88b..916e99ce 100644 --- a/Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift +++ b/Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PSQLFrontendMessage { struct SASLInitialResponse: PayloadEncodable, Equatable { diff --git a/Sources/PostgresNIO/New/Messages/SASLResponse.swift b/Sources/PostgresNIO/New/Messages/SASLResponse.swift index 6391bdb1..8785edcd 100644 --- a/Sources/PostgresNIO/New/Messages/SASLResponse.swift +++ b/Sources/PostgresNIO/New/Messages/SASLResponse.swift @@ -1,3 +1,5 @@ +import NIOCore + extension PSQLFrontendMessage { struct SASLResponse: PayloadEncodable, Equatable { diff --git a/Sources/PostgresNIO/New/Messages/SSLRequest.swift b/Sources/PostgresNIO/New/Messages/SSLRequest.swift index 19ec011c..8995804a 100644 --- a/Sources/PostgresNIO/New/Messages/SSLRequest.swift +++ b/Sources/PostgresNIO/New/Messages/SSLRequest.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PSQLFrontendMessage { /// A message asking the PostgreSQL server if TLS is supported diff --git a/Sources/PostgresNIO/New/Messages/Startup.swift b/Sources/PostgresNIO/New/Messages/Startup.swift index 394efdd7..7e6e7db7 100644 --- a/Sources/PostgresNIO/New/Messages/Startup.swift +++ b/Sources/PostgresNIO/New/Messages/Startup.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore extension PSQLFrontendMessage { struct Startup: PayloadEncodable, Equatable { diff --git a/Sources/PostgresNIO/New/PSQL+JSON.swift b/Sources/PostgresNIO/New/PSQL+JSON.swift index 7d24b34a..564a2cc1 100644 --- a/Sources/PostgresNIO/New/PSQL+JSON.swift +++ b/Sources/PostgresNIO/New/PSQL+JSON.swift @@ -1,6 +1,7 @@ +import NIOCore +import NIOFoundationCompat import class Foundation.JSONEncoder import class Foundation.JSONDecoder -import NIOFoundationCompat protocol PSQLJSONEncoder { func encode(_ value: T, into buffer: inout ByteBuffer) throws diff --git a/Sources/PostgresNIO/New/PSQLBackendMessage.swift b/Sources/PostgresNIO/New/PSQLBackendMessage.swift index 24845f7b..12cd7d27 100644 --- a/Sources/PostgresNIO/New/PSQLBackendMessage.swift +++ b/Sources/PostgresNIO/New/PSQLBackendMessage.swift @@ -1,4 +1,5 @@ -import struct Foundation.Data +import NIOCore +//import struct Foundation.Data /// A protocol to implement for all associated value in the `PSQLBackendMessage` enum diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PSQLChannelHandler.swift index 84819d24..b4606639 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PSQLChannelHandler.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore import NIOTLS import Crypto import Logging diff --git a/Sources/PostgresNIO/New/PSQLCodable.swift b/Sources/PostgresNIO/New/PSQLCodable.swift index 32e30e49..b5434edd 100644 --- a/Sources/PostgresNIO/New/PSQLCodable.swift +++ b/Sources/PostgresNIO/New/PSQLCodable.swift @@ -1,3 +1,5 @@ +import NIOCore + /// A type that can encode itself to a postgres wire binary representation. protocol PSQLEncodable { /// identifies the data type that we will encode into `byteBuffer` in `encode` diff --git a/Sources/PostgresNIO/New/PSQLConnection.swift b/Sources/PostgresNIO/New/PSQLConnection.swift index 5869334b..a99253b6 100644 --- a/Sources/PostgresNIO/New/PSQLConnection.swift +++ b/Sources/PostgresNIO/New/PSQLConnection.swift @@ -1,4 +1,5 @@ -import NIO +import NIOCore +import NIOPosix import NIOFoundationCompat import NIOSSL import class Foundation.JSONEncoder diff --git a/Sources/PostgresNIO/New/PSQLData.swift b/Sources/PostgresNIO/New/PSQLData.swift index 522ae585..840d798a 100644 --- a/Sources/PostgresNIO/New/PSQLData.swift +++ b/Sources/PostgresNIO/New/PSQLData.swift @@ -1,3 +1,5 @@ +import NIOCore + /// The format the postgres types are encoded in on the wire. /// /// Currently there a two wire formats supported: diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index 03998d4a..49825892 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -1,4 +1,4 @@ -import struct Foundation.Data +import NIOCore struct PSQLError: Error { diff --git a/Sources/PostgresNIO/New/PSQLEventsHandler.swift b/Sources/PostgresNIO/New/PSQLEventsHandler.swift index e83e0637..2c9aeaa1 100644 --- a/Sources/PostgresNIO/New/PSQLEventsHandler.swift +++ b/Sources/PostgresNIO/New/PSQLEventsHandler.swift @@ -1,4 +1,6 @@ +import NIOCore import NIOTLS +import Logging enum PSQLOutgoingEvent { /// the event we send down the channel to inform the `PSQLChannelHandler` to authenticate diff --git a/Sources/PostgresNIO/New/PSQLFrontendMessage.swift b/Sources/PostgresNIO/New/PSQLFrontendMessage.swift index 37488b3d..ee7196c8 100644 --- a/Sources/PostgresNIO/New/PSQLFrontendMessage.swift +++ b/Sources/PostgresNIO/New/PSQLFrontendMessage.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore /// A wire message that is created by a Postgres client to be consumed by Postgres server. /// diff --git a/Sources/PostgresNIO/New/PSQLRows.swift b/Sources/PostgresNIO/New/PSQLRows.swift index 23efe393..a8632e7c 100644 --- a/Sources/PostgresNIO/New/PSQLRows.swift +++ b/Sources/PostgresNIO/New/PSQLRows.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore import Logging final class PSQLRows { diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 07ea10ca..895201b8 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -1,3 +1,6 @@ +import Logging +import NIOCore + enum PSQLTask { case extendedQuery(ExtendedQueryContext) case preparedStatement(PrepareStatementContext) diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index db854864..7af85fd3 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -1,3 +1,5 @@ +import NIOCore + struct PostgresJSONDecoderWrapper: PSQLJSONDecoder { let downstream: PostgresJSONDecoder diff --git a/Sources/PostgresNIO/PostgresDatabase+Query.swift b/Sources/PostgresNIO/PostgresDatabase+Query.swift index f9823ae6..b6c0b183 100644 --- a/Sources/PostgresNIO/PostgresDatabase+Query.swift +++ b/Sources/PostgresNIO/PostgresDatabase+Query.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore import Logging extension PostgresDatabase { diff --git a/Sources/PostgresNIO/PostgresDatabase+SimpleQuery.swift b/Sources/PostgresNIO/PostgresDatabase+SimpleQuery.swift index 64c9b919..77f3d034 100644 --- a/Sources/PostgresNIO/PostgresDatabase+SimpleQuery.swift +++ b/Sources/PostgresNIO/PostgresDatabase+SimpleQuery.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore import Logging extension PostgresDatabase { diff --git a/Sources/PostgresNIO/PostgresDatabase.swift b/Sources/PostgresNIO/PostgresDatabase.swift index 3f6f826f..64e44abb 100644 --- a/Sources/PostgresNIO/PostgresDatabase.swift +++ b/Sources/PostgresNIO/PostgresDatabase.swift @@ -1,3 +1,6 @@ +import NIOCore +import Logging + public protocol PostgresDatabase { var logger: Logger { get } var eventLoop: EventLoop { get } diff --git a/Sources/PostgresNIO/Utilities/Exports.swift b/Sources/PostgresNIO/Utilities/Exports.swift index 9c388b65..4224d53f 100644 --- a/Sources/PostgresNIO/Utilities/Exports.swift +++ b/Sources/PostgresNIO/Utilities/Exports.swift @@ -1,3 +1,4 @@ +// TODO: Remove this with the next major release! @_exported import NIO @_exported import NIOSSL @_exported import struct Logging.Logger diff --git a/Sources/PostgresNIO/Utilities/NIOUtils.swift b/Sources/PostgresNIO/Utilities/NIOUtils.swift index 1523b4f5..75ab8c20 100644 --- a/Sources/PostgresNIO/Utilities/NIOUtils.swift +++ b/Sources/PostgresNIO/Utilities/NIOUtils.swift @@ -1,5 +1,5 @@ import Foundation -import NIO +import NIOCore internal extension ByteBuffer { mutating func readInteger(endianness: Endianness = .big, as rawRepresentable: E.Type) -> E? where E: RawRepresentable, E.RawValue: FixedWidthInteger { diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 2d9767c9..d3b25ef7 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -1,6 +1,8 @@ +import XCTest import Logging @testable import PostgresNIO -import XCTest +import NIOCore +import NIOPosix import NIOTestUtils final class IntegrationTests: XCTestCase { diff --git a/Tests/IntegrationTests/PerformanceTests.swift b/Tests/IntegrationTests/PerformanceTests.swift index a26748c4..7e74a595 100644 --- a/Tests/IntegrationTests/PerformanceTests.swift +++ b/Tests/IntegrationTests/PerformanceTests.swift @@ -1,6 +1,8 @@ +import XCTest import Logging +import NIOCore +import NIOPosix import PostgresNIO -import XCTest import NIOTestUtils final class PerformanceTests: XCTestCase { diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index d4095658..edc915dd 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -1,6 +1,8 @@ import Logging @testable import PostgresNIO import XCTest +import NIOCore +import NIOPosix import NIOTestUtils final class PostgresNIOTests: XCTestCase { diff --git a/Tests/IntegrationTests/Utilities.swift b/Tests/IntegrationTests/Utilities.swift index 3c762219..0964f947 100644 --- a/Tests/IntegrationTests/Utilities.swift +++ b/Tests/IntegrationTests/Utilities.swift @@ -1,5 +1,6 @@ -import PostgresNIO import XCTest +import PostgresNIO +import NIOCore import Logging #if canImport(Darwin) import Darwin.C diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift index a1cfbb5c..b503f1ad 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class AuthenticationStateMachineTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index b2ee2652..eb282444 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -1,6 +1,8 @@ import XCTest @testable import PostgresNIO -@testable import NIO +@testable import NIOCore +import NIOPosix +import NIOSSL class ConnectionStateMachineTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index 32766aab..6cb48324 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -1,4 +1,7 @@ import XCTest +import NIOCore +import NIOEmbedded +import Logging @testable import PostgresNIO class ExtendedQueryStateMachineTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift index d822e687..9b88af9a 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOEmbedded @testable import PostgresNIO class PrepareStatementStateMachineTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift index e771b080..1079205e 100644 --- a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class Array_PSQLCodableTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift index 00ec1bf5..f7d40834 100644 --- a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class Bool_PSQLCodableTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift index 914efbff..7d58b660 100644 --- a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class Bytes_PSQLCodableTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift index 38f6dd87..aae7ad8b 100644 --- a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class Date_PSQLCodableTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift index 143f907c..33b8c0da 100644 --- a/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class Float_PSQLCodableTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift index 883a3dfc..325641e8 100644 --- a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class JSON_PSQLCodableTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift index 3900505b..ead0a1b4 100644 --- a/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class Optional_PSQLCodableTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift index 1e472366..cf233890 100644 --- a/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class RawRepresentable_PSQLCodableTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift index c8cce8f1..304bb7d6 100644 --- a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class String_PSQLCodableTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift index 7325955b..8b1be81e 100644 --- a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class UUID_PSQLCodableTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift b/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift index 6b5aa0ac..9d1cfb81 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore @testable import PostgresNIO extension ByteBuffer { diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift index d99c4280..dc7aaa7b 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -1,4 +1,5 @@ import class Foundation.JSONEncoder +import NIOCore @testable import PostgresNIO extension ConnectionStateMachine.ConnectionAction: Equatable { diff --git a/Tests/PostgresNIOTests/New/Extensions/LoggingUtils.swift b/Tests/PostgresNIOTests/New/Extensions/LoggingUtils.swift deleted file mode 100644 index fdada802..00000000 --- a/Tests/PostgresNIOTests/New/Extensions/LoggingUtils.swift +++ /dev/null @@ -1,2 +0,0 @@ -import Logging - diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessage+Equatable.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessage+Equatable.swift index 6ab452b7..36453b7c 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessage+Equatable.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessage+Equatable.swift @@ -1,6 +1,7 @@ +import NIOCore +@testable import PostgresNIO import class Foundation.JSONEncoder import class Foundation.JSONDecoder -@testable import PostgresNIO extension PSQLFrontendMessage.Bind: Equatable { public static func ==(lhs: Self, rhs: Self) -> Bool { diff --git a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift index 60c90703..63281b28 100644 --- a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift @@ -1,6 +1,6 @@ -import NIO -import NIOTestUtils import XCTest +import NIOCore +import NIOTestUtils @testable import PostgresNIO class AuthenticationTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift b/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift index efcfe358..197c49a8 100644 --- a/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift @@ -1,6 +1,6 @@ -import NIO -import NIOTestUtils import XCTest +import NIOCore +import NIOTestUtils @testable import PostgresNIO class BackendKeyDataTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/BindTests.swift b/Tests/PostgresNIOTests/New/Messages/BindTests.swift index de6ace78..43e7e7cf 100644 --- a/Tests/PostgresNIOTests/New/Messages/BindTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BindTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class BindTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/CancelTests.swift b/Tests/PostgresNIOTests/New/Messages/CancelTests.swift index 333e3644..80ac98d0 100644 --- a/Tests/PostgresNIOTests/New/Messages/CancelTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CancelTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class CancelTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift index 90cd989e..c75fe78a 100644 --- a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class CloseTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift index b55dd1b6..497534a9 100644 --- a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift @@ -1,6 +1,6 @@ -import NIO -import NIOTestUtils import XCTest +import NIOCore +import NIOTestUtils @testable import PostgresNIO class DataRowTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift index 7566daf8..777e0769 100644 --- a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class DescribeTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift index a2b4113e..c78f48cc 100644 --- a/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift @@ -1,6 +1,6 @@ -import NIO -import NIOTestUtils import XCTest +import NIOCore +import NIOTestUtils @testable import PostgresNIO class ErrorResponseTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift index 1c68f3be..177e5fcb 100644 --- a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class ExecuteTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift index cb7f37c5..b1ff469a 100644 --- a/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift @@ -1,6 +1,6 @@ -import NIO -import NIOTestUtils import XCTest +import NIOCore +import NIOTestUtils @testable import PostgresNIO class NotificationResponseTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift b/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift index af316a15..9ff80abf 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift @@ -1,6 +1,6 @@ -import NIO -import NIOTestUtils import XCTest +import NIOCore +import NIOTestUtils @testable import PostgresNIO class ParameterDescriptionTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift b/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift index 2d256dc9..b0180725 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift @@ -1,6 +1,6 @@ -import NIO -import NIOTestUtils import XCTest +import NIOCore +import NIOTestUtils @testable import PostgresNIO class ParameterStatusTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift index 223d6002..146fc57f 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift @@ -1,6 +1,5 @@ -import NIO -import NIOTestUtils import XCTest +import NIOCore @testable import PostgresNIO class ParseTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift index 7c8f13c4..75ab3a85 100644 --- a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class PasswordTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift b/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift index 029f627a..41527567 100644 --- a/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift @@ -1,6 +1,6 @@ -import NIO -import NIOTestUtils import XCTest +import NIOCore +import NIOTestUtils @testable import PostgresNIO class ReadyForQueryTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift index bbcb103d..412bfb9d 100644 --- a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift @@ -1,6 +1,6 @@ -import NIO -import NIOTestUtils import XCTest +import NIOCore +import NIOTestUtils @testable import PostgresNIO class RowDescriptionTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift index 0c2c5823..c846c260 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift @@ -1,6 +1,5 @@ -import NIO -import NIOTestUtils import XCTest +import NIOCore @testable import PostgresNIO class SASLInitialResponseTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift index 28dd46c8..e4556ac2 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift @@ -1,6 +1,5 @@ -import NIO -import NIOTestUtils import XCTest +import NIOCore @testable import PostgresNIO class SASLResponseTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift b/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift index 917ea24d..7f8e57f4 100644 --- a/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class SSLRequestTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift index 73667585..3a386bd3 100644 --- a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class StartupTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift index 717fa455..acfef769 100644 --- a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift @@ -1,4 +1,5 @@ -import NIO +import NIOCore +import NIOEmbedded import NIOTestUtils import XCTest @testable import PostgresNIO diff --git a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift index 929337ba..23ebbbde 100644 --- a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift @@ -1,6 +1,8 @@ import XCTest -import NIO +import NIOCore import NIOTLS +import NIOSSL +import NIOEmbedded @testable import PostgresNIO class PSQLChannelHandlerTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift b/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift index 2a1c7e97..708c6c0e 100644 --- a/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift @@ -1,4 +1,5 @@ -import NIO +import NIOCore +import NIOPosix import XCTest import Logging @testable import PostgresNIO diff --git a/Tests/PostgresNIOTests/New/PSQLDataTests.swift b/Tests/PostgresNIOTests/New/PSQLDataTests.swift index 55873310..c76b8d07 100644 --- a/Tests/PostgresNIOTests/New/PSQLDataTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLDataTests.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore import XCTest @testable import PostgresNIO diff --git a/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift index 182a8678..69fa1374 100644 --- a/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift @@ -1,4 +1,5 @@ import XCTest +import NIOCore @testable import PostgresNIO class PSQLFrontendMessageTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/PSQLMetadataTests.swift b/Tests/PostgresNIOTests/New/PSQLMetadataTests.swift index e190c740..b069b4f0 100644 --- a/Tests/PostgresNIOTests/New/PSQLMetadataTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLMetadataTests.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore import XCTest @testable import PostgresNIO From 6a9eb6f15235e844c95382959a88253442168a42 Mon Sep 17 00:00:00 2001 From: "Juan A. Reyes" <59104004+jareyesda@users.noreply.github.com> Date: Wed, 18 Aug 2021 06:08:31 -0400 Subject: [PATCH 014/292] Fix NIOSSL deprecation warnings - Package.swift had the `NIOSSL` package raised to 2.14.1 from the previous 2.12.0. - All instances of `TLSConfiguration.forClient` have been replaced with `TLSConfiguration.makeClientConfiguration` Co-authored-by: Fabian Fett --- Package.swift | 2 +- Tests/IntegrationTests/PostgresNIOTests.swift | 4 ++-- Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Package.swift b/Package.swift index 1c393213..c46089e0 100644 --- a/Package.swift +++ b/Package.swift @@ -14,7 +14,7 @@ let package = Package( ], dependencies: [ .package(url: "https://github.com/apple/swift-nio.git", from: "2.32.0"), - .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.12.0"), + .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.1"), .package(url: "https://github.com/apple/swift-crypto.git", from: "1.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.0.0"), .package(url: "https://github.com/apple/swift-log.git", from: "1.4.0"), diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index edc915dd..308ecfee 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -680,7 +680,7 @@ final class PostgresNIOTests: XCTestCase { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.connect( to: SocketAddress.makeAddressResolvingHost("elmer.db.elephantsql.com", port: 5432), - tlsConfiguration: .forClient(certificateVerification: .none), + tlsConfiguration: .makeClientConfiguration(), serverHostname: "elmer.db.elephantsql.com", on: eventLoop ).wait()) @@ -709,7 +709,7 @@ final class PostgresNIOTests: XCTestCase { XCTAssertThrowsError( try PostgresConnection.connect( to: SocketAddress.makeAddressResolvingHost("elmer.db.elephantsql.com", port: 5432), - tlsConfiguration: .forClient(certificateVerification: .fullVerification), + tlsConfiguration: .makeClientConfiguration(), serverHostname: "34.228.73.168", on: eventLoop ).wait() diff --git a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift index 23ebbbde..b0456d49 100644 --- a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift @@ -34,7 +34,7 @@ class PSQLChannelHandlerTests: XCTestCase { func testEstablishSSLCallbackIsCalledIfSSLIsSupported() { var config = self.testConnectionConfiguration() - config.tlsConfiguration = .forClient(certificateVerification: .none) + config.tlsConfiguration = .makeClientConfiguration() var addSSLCallbackIsHit = false let handler = PSQLChannelHandler(authentification: config.authentication) { channel in addSSLCallbackIsHit = true @@ -72,7 +72,7 @@ class PSQLChannelHandlerTests: XCTestCase { func testSSLUnsupportedClosesConnection() { var config = self.testConnectionConfiguration() - config.tlsConfiguration = .forClient() + config.tlsConfiguration = .makeClientConfiguration() let handler = PSQLChannelHandler(authentification: config.authentication) { channel in XCTFail("This callback should never be exectuded") From 7acca81e6dc0dc98004c65fd25b0e7e172eaa4d1 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 18 Sep 2021 13:36:24 +0200 Subject: [PATCH 015/292] PSQLBackendMessageDecoder is a SingleStepDecoder (#174) ### Motivation We want to use the `PSQLBackendMessageDecoder` as `NIOSingleStepByteToMessageDecoder` in the future. ### Changes - Rename `PSQLBackendMessage.Decoder` to `PSQLBackendMessageDecoder`. Namespacing the Decoder in its own MessageType was a stupid idea. Sorry. Reverting this now. - `PSQLBackendMessageDecoder` get's its own file. Implementation copy and pasted. - `PSQLBackendMessageDecoder` implements `NIOSingleStepByteToMessageDecoder` protocol, which has an auto-conformance to `ByteToMessageDecoder` - Move `PSQLBackendMessage.ensureAtLeastNBytesRemaining` into an internal extension on `ByteBuffer` ### Result - Cleaner, less clever code. --- .../New/Messages/Authentication.swift | 6 +- .../New/Messages/BackendKeyData.swift | 2 +- .../PostgresNIO/New/Messages/DataRow.swift | 6 +- .../New/Messages/ErrorResponse.swift | 4 +- .../New/Messages/NotificationResponse.swift | 6 +- .../New/Messages/ParameterDescription.swift | 6 +- .../New/Messages/ParameterStatus.swift | 4 +- .../New/Messages/ReadyForQuery.swift | 6 +- .../New/Messages/RowDescription.swift | 10 +- .../PostgresNIO/New/PSQLBackendMessage.swift | 222 +----------------- .../New/PSQLBackendMessageDecoder.swift | 207 ++++++++++++++++ Sources/PostgresNIO/New/PSQLConnection.swift | 2 +- Sources/PostgresNIO/New/PSQLError.swift | 4 +- .../New/Messages/AuthenticationTests.swift | 2 +- .../New/Messages/BackendKeyDataTests.swift | 6 +- .../New/Messages/DataRowTests.swift | 2 +- .../New/Messages/ErrorResponseTests.swift | 2 +- .../Messages/NotificationResponseTests.swift | 10 +- .../Messages/ParameterDescriptionTests.swift | 10 +- .../New/Messages/ParameterStatusTests.swift | 10 +- .../New/Messages/ReadyForQueryTests.swift | 10 +- .../New/Messages/RowDescriptionTests.swift | 18 +- .../New/PSQLBackendMessageTests.swift | 20 +- 23 files changed, 286 insertions(+), 289 deletions(-) create mode 100644 Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift diff --git a/Sources/PostgresNIO/New/Messages/Authentication.swift b/Sources/PostgresNIO/New/Messages/Authentication.swift index ef1ec2d3..5ce5b857 100644 --- a/Sources/PostgresNIO/New/Messages/Authentication.swift +++ b/Sources/PostgresNIO/New/Messages/Authentication.swift @@ -16,7 +16,7 @@ extension PSQLBackendMessage { case saslFinal(data: ByteBuffer) static func decode(from buffer: inout ByteBuffer) throws -> Self { - try PSQLBackendMessage.ensureAtLeastNBytesRemaining(2, in: buffer) + try buffer.ensureAtLeastNBytesRemaining(2) // we have at least two bytes remaining, therefore we can force unwrap this read. let authID = buffer.readInteger(as: Int32.self)! @@ -29,7 +29,7 @@ extension PSQLBackendMessage { case 3: return .plaintext case 5: - try PSQLBackendMessage.ensureExactNBytesRemaining(4, in: buffer) + try buffer.ensureExactNBytesRemaining(4) let salt1 = buffer.readInteger(as: UInt8.self)! let salt2 = buffer.readInteger(as: UInt8.self)! let salt3 = buffer.readInteger(as: UInt8.self)! @@ -59,7 +59,7 @@ extension PSQLBackendMessage { let data = buffer.readSlice(length: buffer.readableBytes)! return .saslFinal(data: data) default: - throw PartialDecodingError.unexpectedValue(value: authID) + throw PSQLPartialDecodingError.unexpectedValue(value: authID) } } diff --git a/Sources/PostgresNIO/New/Messages/BackendKeyData.swift b/Sources/PostgresNIO/New/Messages/BackendKeyData.swift index c9db7907..dfb5738e 100644 --- a/Sources/PostgresNIO/New/Messages/BackendKeyData.swift +++ b/Sources/PostgresNIO/New/Messages/BackendKeyData.swift @@ -7,7 +7,7 @@ extension PSQLBackendMessage { let secretKey: Int32 static func decode(from buffer: inout ByteBuffer) throws -> Self { - try PSQLBackendMessage.ensureExactNBytesRemaining(8, in: buffer) + try buffer.ensureExactNBytesRemaining(8) // We have verified the correct length before, this means we have exactly eight bytes // to read. If we have enough readable bytes, a read of Int32 should always succeed. diff --git a/Sources/PostgresNIO/New/Messages/DataRow.swift b/Sources/PostgresNIO/New/Messages/DataRow.swift index 3d1b982b..3047ccc2 100644 --- a/Sources/PostgresNIO/New/Messages/DataRow.swift +++ b/Sources/PostgresNIO/New/Messages/DataRow.swift @@ -7,14 +7,14 @@ extension PSQLBackendMessage { var columns: [ByteBuffer?] static func decode(from buffer: inout ByteBuffer) throws -> Self { - try PSQLBackendMessage.ensureAtLeastNBytesRemaining(2, in: buffer) + try buffer.ensureAtLeastNBytesRemaining(2) let columnCount = buffer.readInteger(as: Int16.self)! var result = [ByteBuffer?]() result.reserveCapacity(Int(columnCount)) for _ in 0..= 0 else { @@ -22,7 +22,7 @@ extension PSQLBackendMessage { continue } - try PSQLBackendMessage.ensureAtLeastNBytesRemaining(bufferLength, in: buffer) + try buffer.ensureAtLeastNBytesRemaining(bufferLength) let columnBuffer = buffer.readSlice(length: Int(bufferLength))! result.append(columnBuffer) diff --git a/Sources/PostgresNIO/New/Messages/ErrorResponse.swift b/Sources/PostgresNIO/New/Messages/ErrorResponse.swift index 4dc5bc99..254cdf0f 100644 --- a/Sources/PostgresNIO/New/Messages/ErrorResponse.swift +++ b/Sources/PostgresNIO/New/Messages/ErrorResponse.swift @@ -112,13 +112,13 @@ extension PSQLBackendMessage.PayloadDecodable where Self: PSQLMessageNotice { break } guard let field = PSQLBackendMessage.Field(rawValue: id) else { - throw PSQLBackendMessage.PartialDecodingError.valueNotRawRepresentable( + throw PSQLPartialDecodingError.valueNotRawRepresentable( value: id, asType: PSQLBackendMessage.Field.self) } guard let string = buffer.readNullTerminatedString() else { - throw PSQLBackendMessage.PartialDecodingError.fieldNotDecodable(type: String.self) + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } fields[field] = string } diff --git a/Sources/PostgresNIO/New/Messages/NotificationResponse.swift b/Sources/PostgresNIO/New/Messages/NotificationResponse.swift index 72ff8141..b1430e2a 100644 --- a/Sources/PostgresNIO/New/Messages/NotificationResponse.swift +++ b/Sources/PostgresNIO/New/Messages/NotificationResponse.swift @@ -8,14 +8,14 @@ extension PSQLBackendMessage { let payload: String static func decode(from buffer: inout ByteBuffer) throws -> PSQLBackendMessage.NotificationResponse { - try PSQLBackendMessage.ensureAtLeastNBytesRemaining(6, in: buffer) + try buffer.ensureAtLeastNBytesRemaining(6) let backendPID = buffer.readInteger(as: Int32.self)! guard let channel = buffer.readNullTerminatedString() else { - throw PartialDecodingError.fieldNotDecodable(type: String.self) + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } guard let payload = buffer.readNullTerminatedString() else { - throw PartialDecodingError.fieldNotDecodable(type: String.self) + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } return NotificationResponse(backendPID: backendPID, channel: channel, payload: payload) diff --git a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift index 340492c7..fdf64aad 100644 --- a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift +++ b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift @@ -7,14 +7,14 @@ extension PSQLBackendMessage { var dataTypes: [PSQLDataType] static func decode(from buffer: inout ByteBuffer) throws -> Self { - try PSQLBackendMessage.ensureAtLeastNBytesRemaining(2, in: buffer) + try buffer.ensureAtLeastNBytesRemaining(2) let parameterCount = buffer.readInteger(as: Int16.self)! guard parameterCount >= 0 else { - throw PartialDecodingError.integerMustBePositiveOrNull(parameterCount) + throw PSQLPartialDecodingError.integerMustBePositiveOrNull(parameterCount) } - try PSQLBackendMessage.ensureExactNBytesRemaining(Int(parameterCount) * 4, in: buffer) + try buffer.ensureExactNBytesRemaining(Int(parameterCount) * 4) var result = [PSQLDataType]() result.reserveCapacity(Int(parameterCount)) diff --git a/Sources/PostgresNIO/New/Messages/ParameterStatus.swift b/Sources/PostgresNIO/New/Messages/ParameterStatus.swift index 0a40962b..89dd1d6d 100644 --- a/Sources/PostgresNIO/New/Messages/ParameterStatus.swift +++ b/Sources/PostgresNIO/New/Messages/ParameterStatus.swift @@ -11,11 +11,11 @@ extension PSQLBackendMessage { static func decode(from buffer: inout ByteBuffer) throws -> Self { guard let name = buffer.readNullTerminatedString() else { - throw PartialDecodingError.fieldNotDecodable(type: String.self) + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } guard let value = buffer.readNullTerminatedString() else { - throw PartialDecodingError.fieldNotDecodable(type: String.self) + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } return ParameterStatus(parameter: name, value: value) diff --git a/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift b/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift index 35529db1..20420763 100644 --- a/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift +++ b/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift @@ -33,14 +33,12 @@ extension PSQLBackendMessage { } static func decode(from buffer: inout ByteBuffer) throws -> Self { - guard buffer.readableBytes == 1 else { - throw PartialDecodingError.expectedExactlyNRemainingBytes(1, actual: buffer.readableBytes) - } + try buffer.ensureExactNBytesRemaining(1) // Exactly one byte is readable. For this reason, we can force unwrap the UInt8 below let value = buffer.readInteger(as: UInt8.self)! guard let state = Self.init(rawValue: value) else { - throw PartialDecodingError.valueNotRawRepresentable(value: value, asType: TransactionState.self) + throw PSQLPartialDecodingError.valueNotRawRepresentable(value: value, asType: TransactionState.self) } return state diff --git a/Sources/PostgresNIO/New/Messages/RowDescription.swift b/Sources/PostgresNIO/New/Messages/RowDescription.swift index 20845958..90f0bbac 100644 --- a/Sources/PostgresNIO/New/Messages/RowDescription.swift +++ b/Sources/PostgresNIO/New/Messages/RowDescription.swift @@ -31,11 +31,11 @@ extension PSQLBackendMessage { } static func decode(from buffer: inout ByteBuffer) throws -> Self { - try PSQLBackendMessage.ensureAtLeastNBytesRemaining(2, in: buffer) + try buffer.ensureAtLeastNBytesRemaining(2) let columnCount = buffer.readInteger(as: Int16.self)! guard columnCount >= 0 else { - throw PartialDecodingError.integerMustBePositiveOrNull(columnCount) + throw PSQLPartialDecodingError.integerMustBePositiveOrNull(columnCount) } var result = [Column]() @@ -43,10 +43,10 @@ extension PSQLBackendMessage { for _ in 0.. DecodingState { - // make sure we have at least one byte to read - guard buffer.readableBytes > 0 else { - return .needMoreData - } - - if !self.hasAlreadyReceivedBytes { - // We have not received any bytes yet! Let's peek at the first message id. If it - // is a "S" or "N" we assume that it is connected to an SSL upgrade request. All - // other messages that we expect now, don't start with either "S" or "N" - - // we made sure, we have at least one byte available, above, thus force unwrap is okay - let firstByte = buffer.getInteger(at: buffer.readerIndex, as: UInt8.self)! - - switch firstByte { - case UInt8(ascii: "S"): - // mark byte as read - buffer.moveReaderIndex(forwardBy: 1) - context.fireChannelRead(NIOAny(PSQLBackendMessage.sslSupported)) - self.hasAlreadyReceivedBytes = true - return .continue - case UInt8(ascii: "N"): - // mark byte as read - buffer.moveReaderIndex(forwardBy: 1) - context.fireChannelRead(NIOAny(PSQLBackendMessage.sslUnsupported)) - self.hasAlreadyReceivedBytes = true - return .continue - default: - self.hasAlreadyReceivedBytes = true - } - } - - // all other packages have an Int32 after the identifier that determines their length. - // do we have enough bytes for that? - guard buffer.readableBytes >= 5 else { - return .needMoreData - } - - let idByte = buffer.getInteger(at: buffer.readerIndex, as: UInt8.self)! - let length = buffer.getInteger(at: buffer.readerIndex + 1, as: Int32.self)! - - guard length + 1 <= buffer.readableBytes else { - return .needMoreData - } - - // At this point we are sure, that we have enough bytes to decode the next message. - // 1. Create a byteBuffer that represents exactly the next message. This can be force - // unwrapped, since it was verified that enough bytes are available. - let completeMessageBuffer = buffer.readSlice(length: 1 + Int(length))! - - // 2. make sure we have a known message identifier - guard let messageID = PSQLBackendMessage.ID(rawValue: idByte) else { - throw DecodingError.unknownMessageIDReceived(messageID: idByte, messageBytes: completeMessageBuffer) - } - - // 3. decode the message - do { - // get a mutable byteBuffer copy - var slice = completeMessageBuffer - // move reader index forward by five bytes - slice.moveReaderIndex(forwardBy: 5) - - let message = try PSQLBackendMessage.decode(from: &slice, for: messageID) - context.fireChannelRead(NIOAny(message)) - } catch let error as PartialDecodingError { - throw DecodingError.withPartialError(error, messageID: messageID, messageBytes: completeMessageBuffer) - } catch { - preconditionFailure("Expected to only see `PartialDecodingError`s here.") - } - - return .continue - } - } -} - extension PSQLBackendMessage: CustomDebugStringConvertible { var debugDescription: String { switch self { @@ -363,129 +281,3 @@ extension PSQLBackendMessage: CustomDebugStringConvertible { } } } - -extension PSQLBackendMessage { - - /// An error representing a failure to decode [a Postgres wire message](https://www.postgresql.org/docs/13/protocol-message-formats.html) - /// to the Swift structure `PSQLBackendMessage`. - /// - /// If you encounter a `DecodingError` when using a trusted Postgres server please make to file an issue at: - /// [https://github.com/vapor/postgres-nio/issues](https://github.com/vapor/postgres-nio/issues) - struct DecodingError: Error { - - /// The backend message ID bytes - let messageID: UInt8 - - /// The backend message's payload encoded in base64 - let payload: String - - /// A textual description of the error - let description: String - - /// The file this error was thrown in - let file: String - - /// The line in `file` this error was thrown - let line: Int - - static func withPartialError( - _ partialError: PartialDecodingError, - messageID: PSQLBackendMessage.ID, - messageBytes: ByteBuffer) -> Self - { - var byteBuffer = messageBytes - let data = byteBuffer.readData(length: byteBuffer.readableBytes)! - - return DecodingError( - messageID: messageID.rawValue, - payload: data.base64EncodedString(), - description: partialError.description, - file: partialError.file, - line: partialError.line) - } - - static func unknownMessageIDReceived( - messageID: UInt8, - messageBytes: ByteBuffer, - file: String = #file, - line: Int = #line) -> Self - { - var byteBuffer = messageBytes - let data = byteBuffer.readData(length: byteBuffer.readableBytes)! - - return DecodingError( - messageID: messageID, - payload: data.base64EncodedString(), - description: "Received a message with messageID '\(Character(UnicodeScalar(messageID)))'. There is no message type associated with this message identifier.", - file: file, - line: line) - } - - } - - struct PartialDecodingError: Error { - /// A textual description of the error - let description: String - - /// The file this error was thrown in - let file: String - - /// The line in `file` this error was thrown - let line: Int - - static func valueNotRawRepresentable( - value: Target.RawValue, - asType: Target.Type, - file: String = #file, - line: Int = #line) -> Self - { - return PartialDecodingError( - description: "Can not represent '\(value)' with type '\(asType)'.", - file: file, line: line) - } - - static func unexpectedValue(value: Any, file: String = #file, line: Int = #line) -> Self { - return PartialDecodingError( - description: "Value '\(value)' is not expected.", - file: file, line: line) - } - - static func expectedAtLeastNRemainingBytes(_ expected: Int, actual: Int, file: String = #file, line: Int = #line) -> Self { - return PartialDecodingError( - description: "Expected at least '\(expected)' remaining bytes. But only found \(actual).", - file: file, line: line) - } - - static func expectedExactlyNRemainingBytes(_ expected: Int, actual: Int, file: String = #file, line: Int = #line) -> Self { - return PartialDecodingError( - description: "Expected exactly '\(expected)' remaining bytes. But found \(actual).", - file: file, line: line) - } - - static func fieldNotDecodable(type: Any.Type, file: String = #file, line: Int = #line) -> Self { - return PartialDecodingError( - description: "Could not read '\(type)' from ByteBuffer.", - file: file, line: line) - } - - static func integerMustBePositiveOrNull(_ actual: Number, file: String = #file, line: Int = #line) -> Self { - return PartialDecodingError( - description: "Expected the integer to be positive or null, but got \(actual).", - file: file, line: line) - } - } - - @inline(__always) - static func ensureAtLeastNBytesRemaining(_ n: Int, in buffer: ByteBuffer, file: String = #file, line: Int = #line) throws { - guard buffer.readableBytes >= n else { - throw PartialDecodingError.expectedAtLeastNRemainingBytes(2, actual: buffer.readableBytes, file: file, line: line) - } - } - - @inline(__always) - static func ensureExactNBytesRemaining(_ n: Int, in buffer: ByteBuffer, file: String = #file, line: Int = #line) throws { - guard buffer.readableBytes == n else { - throw PartialDecodingError.expectedExactlyNRemainingBytes(n, actual: buffer.readableBytes, file: file, line: line) - } - } -} diff --git a/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift b/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift new file mode 100644 index 00000000..58f5c460 --- /dev/null +++ b/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift @@ -0,0 +1,207 @@ +struct PSQLBackendMessageDecoder: NIOSingleStepByteToMessageDecoder { + typealias InboundOut = PSQLBackendMessage + + private(set) var hasAlreadyReceivedBytes: Bool + + init(hasAlreadyReceivedBytes: Bool = false) { + self.hasAlreadyReceivedBytes = hasAlreadyReceivedBytes + } + + mutating func decode(buffer: inout ByteBuffer) throws -> PSQLBackendMessage? { + // make sure we have at least one byte to read + guard buffer.readableBytes > 0 else { + return nil + } + + if !self.hasAlreadyReceivedBytes { + // We have not received any bytes yet! Let's peek at the first message id. If it + // is a "S" or "N" we assume that it is connected to an SSL upgrade request. All + // other messages that we expect now, don't start with either "S" or "N" + + // we made sure, we have at least one byte available, above, thus force unwrap is okay + let firstByte = buffer.getInteger(at: buffer.readerIndex, as: UInt8.self)! + + switch firstByte { + case UInt8(ascii: "S"): + // mark byte as read + buffer.moveReaderIndex(forwardBy: 1) + self.hasAlreadyReceivedBytes = true + return .sslSupported + case UInt8(ascii: "N"): + // mark byte as read + buffer.moveReaderIndex(forwardBy: 1) + self.hasAlreadyReceivedBytes = true + return .sslUnsupported + default: + self.hasAlreadyReceivedBytes = true + } + } + + // all other packages have an Int32 after the identifier that determines their length. + // do we have enough bytes for that? + guard buffer.readableBytes >= 5 else { + return nil + } + + let idByte = buffer.getInteger(at: buffer.readerIndex, as: UInt8.self)! + let length = buffer.getInteger(at: buffer.readerIndex + 1, as: Int32.self)! + + guard length + 1 <= buffer.readableBytes else { + return nil + } + + // At this point we are sure, that we have enough bytes to decode the next message. + // 1. Create a byteBuffer that represents exactly the next message. This can be force + // unwrapped, since it was verified that enough bytes are available. + let completeMessageBuffer = buffer.readSlice(length: 1 + Int(length))! + + // 2. make sure we have a known message identifier + guard let messageID = PSQLBackendMessage.ID(rawValue: idByte) else { + throw PSQLDecodingError.unknownMessageIDReceived(messageID: idByte, messageBytes: completeMessageBuffer) + } + + // 3. decode the message + do { + // get a mutable byteBuffer copy + var slice = completeMessageBuffer + // move reader index forward by five bytes + slice.moveReaderIndex(forwardBy: 5) + + return try PSQLBackendMessage.decode(from: &slice, for: messageID) + } catch let error as PSQLPartialDecodingError { + throw PSQLDecodingError.withPartialError(error, messageID: messageID, messageBytes: completeMessageBuffer) + } catch { + preconditionFailure("Expected to only see `PartialDecodingError`s here.") + } + } + + mutating func decodeLast(buffer: inout ByteBuffer, seenEOF: Bool) throws -> PSQLBackendMessage? { + try self.decode(buffer: &buffer) + } +} + + + +/// An error representing a failure to decode [a Postgres wire message](https://www.postgresql.org/docs/13/protocol-message-formats.html) +/// to the Swift structure `PSQLBackendMessage`. +/// +/// If you encounter a `DecodingError` when using a trusted Postgres server please make to file an issue at: +/// [https://github.com/vapor/postgres-nio/issues](https://github.com/vapor/postgres-nio/issues) +struct PSQLDecodingError: Error { + + /// The backend message ID bytes + let messageID: UInt8 + + /// The backend message's payload encoded in base64 + let payload: String + + /// A textual description of the error + let description: String + + /// The file this error was thrown in + let file: String + + /// The line in `file` this error was thrown + let line: Int + + static func withPartialError( + _ partialError: PSQLPartialDecodingError, + messageID: PSQLBackendMessage.ID, + messageBytes: ByteBuffer) -> Self + { + var byteBuffer = messageBytes + let data = byteBuffer.readData(length: byteBuffer.readableBytes)! + + return PSQLDecodingError( + messageID: messageID.rawValue, + payload: data.base64EncodedString(), + description: partialError.description, + file: partialError.file, + line: partialError.line) + } + + static func unknownMessageIDReceived( + messageID: UInt8, + messageBytes: ByteBuffer, + file: String = #file, + line: Int = #line) -> Self + { + var byteBuffer = messageBytes + let data = byteBuffer.readData(length: byteBuffer.readableBytes)! + + return PSQLDecodingError( + messageID: messageID, + payload: data.base64EncodedString(), + description: "Received a message with messageID '\(Character(UnicodeScalar(messageID)))'. There is no message type associated with this message identifier.", + file: file, + line: line) + } + +} + +struct PSQLPartialDecodingError: Error { + /// A textual description of the error + let description: String + + /// The file this error was thrown in + let file: String + + /// The line in `file` this error was thrown + let line: Int + + static func valueNotRawRepresentable( + value: Target.RawValue, + asType: Target.Type, + file: String = #file, + line: Int = #line) -> Self + { + return PSQLPartialDecodingError( + description: "Can not represent '\(value)' with type '\(asType)'.", + file: file, line: line) + } + + static func unexpectedValue(value: Any, file: String = #file, line: Int = #line) -> Self { + return PSQLPartialDecodingError( + description: "Value '\(value)' is not expected.", + file: file, line: line) + } + + static func expectedAtLeastNRemainingBytes(_ expected: Int, actual: Int, file: String = #file, line: Int = #line) -> Self { + return PSQLPartialDecodingError( + description: "Expected at least '\(expected)' remaining bytes. But only found \(actual).", + file: file, line: line) + } + + static func expectedExactlyNRemainingBytes(_ expected: Int, actual: Int, file: String = #file, line: Int = #line) -> Self { + return PSQLPartialDecodingError( + description: "Expected exactly '\(expected)' remaining bytes. But found \(actual).", + file: file, line: line) + } + + static func fieldNotDecodable(type: Any.Type, file: String = #file, line: Int = #line) -> Self { + return PSQLPartialDecodingError( + description: "Could not read '\(type)' from ByteBuffer.", + file: file, line: line) + } + + static func integerMustBePositiveOrNull(_ actual: Number, file: String = #file, line: Int = #line) -> Self { + return PSQLPartialDecodingError( + description: "Expected the integer to be positive or null, but got \(actual).", + file: file, line: line) + } +} + +extension ByteBuffer { + func ensureAtLeastNBytesRemaining(_ n: Int, file: String = #file, line: Int = #line) throws { + guard self.readableBytes >= n else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(2, actual: self.readableBytes, file: file, line: line) + } + } + + func ensureExactNBytesRemaining(_ n: Int, file: String = #file, line: Int = #line) throws { + guard self.readableBytes == n else { + throw PSQLPartialDecodingError.expectedExactlyNRemainingBytes(n, actual: self.readableBytes, file: file, line: line) + } + } +} + diff --git a/Sources/PostgresNIO/New/PSQLConnection.swift b/Sources/PostgresNIO/New/PSQLConnection.swift index a99253b6..4689ed1a 100644 --- a/Sources/PostgresNIO/New/PSQLConnection.swift +++ b/Sources/PostgresNIO/New/PSQLConnection.swift @@ -214,7 +214,7 @@ final class PSQLConnection { }.flatMap { address -> EventLoopFuture in let bootstrap = ClientBootstrap(group: eventLoop) .channelInitializer { channel in - let decoder = ByteToMessageHandler(PSQLBackendMessage.Decoder()) + let decoder = ByteToMessageHandler(PSQLBackendMessageDecoder()) var configureSSLCallback: ((Channel) throws -> ())? = nil if let tlsConfiguration = configuration.tlsConfiguration { diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index 49825892..0cadc9ee 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -6,7 +6,7 @@ struct PSQLError: Error { case sslUnsupported case failedToAddSSLHandler(underlying: Error) case server(PSQLBackendMessage.ErrorResponse) - case decoding(PSQLBackendMessage.DecodingError) + case decoding(PSQLDecodingError) case unexpectedBackendMessage(PSQLBackendMessage) case unsupportedAuthMechanism(PSQLAuthScheme) case authMechanismRequiresPassword @@ -39,7 +39,7 @@ struct PSQLError: Error { Self.init(.server(message)) } - static func decoding(_ error: PSQLBackendMessage.DecodingError) -> PSQLError { + static func decoding(_ error: PSQLDecodingError) -> PSQLError { Self.init(.decoding(error)) } diff --git a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift index 63281b28..9091b3cf 100644 --- a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift @@ -58,6 +58,6 @@ class AuthenticationTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, expected)], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: false) })) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) } } diff --git a/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift b/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift index 197c49a8..eca5ba02 100644 --- a/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift @@ -16,7 +16,7 @@ class BackendKeyDataTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: expectedInOuts, - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: false) })) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) } func testDecodeInvalidLength() { @@ -32,8 +32,8 @@ class BackendKeyDataTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: expected, - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: false) })) { - XCTAssert($0 is PSQLBackendMessage.DecodingError) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { + XCTAssert($0 is PSQLDecodingError) } } } diff --git a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift index 497534a9..af9ee3f2 100644 --- a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift @@ -32,6 +32,6 @@ class DataRowTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: expectedInOuts, - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: false) })) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) } } diff --git a/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift index c78f48cc..bbc945e4 100644 --- a/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift @@ -30,6 +30,6 @@ class ErrorResponseTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: expectedInOuts, - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: false) })) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) } } diff --git a/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift index b1ff469a..39fbb220 100644 --- a/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift @@ -27,7 +27,7 @@ class NotificationResponseTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, expected)], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: true) })) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) } func testDecodeFailureBecauseOfMissingNullTermination() { @@ -40,8 +40,8 @@ class NotificationResponseTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLBackendMessage.DecodingError) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PSQLDecodingError) } } @@ -55,8 +55,8 @@ class NotificationResponseTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLBackendMessage.DecodingError) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PSQLDecodingError) } } } diff --git a/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift b/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift index 9ff80abf..ebc80a8e 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift @@ -27,7 +27,7 @@ class ParameterDescriptionTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, expected)], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: true) })) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) } func testDecodeWithNegativeCount() { @@ -43,8 +43,8 @@ class ParameterDescriptionTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLBackendMessage.DecodingError) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PSQLDecodingError) } } @@ -62,8 +62,8 @@ class ParameterDescriptionTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLBackendMessage.DecodingError) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PSQLDecodingError) } } } diff --git a/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift b/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift index b0180725..db4963e0 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift @@ -42,7 +42,7 @@ class ParameterStatusTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, expected)], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: true) })) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) } func testDecodeFailureBecauseOfMissingNullTermination() { @@ -54,8 +54,8 @@ class ParameterStatusTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLBackendMessage.DecodingError) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PSQLDecodingError) } } @@ -68,8 +68,8 @@ class ParameterStatusTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLBackendMessage.DecodingError) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PSQLDecodingError) } } } diff --git a/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift b/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift index 41527567..55a2c1e7 100644 --- a/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift @@ -33,7 +33,7 @@ class ReadyForQueryTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, expected)], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: true) })) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) } @@ -47,8 +47,8 @@ class ReadyForQueryTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLBackendMessage.DecodingError) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PSQLDecodingError) } } @@ -61,8 +61,8 @@ class ReadyForQueryTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLBackendMessage.DecodingError) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PSQLDecodingError) } } diff --git a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift index 412bfb9d..4452ebce 100644 --- a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift @@ -38,7 +38,7 @@ class RowDescriptionTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, expected)], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: true) })) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) } func testDecodeFailureBecauseOfMissingNullTerminationInColumnName() { @@ -59,8 +59,8 @@ class RowDescriptionTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLBackendMessage.DecodingError) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PSQLDecodingError) } } @@ -81,8 +81,8 @@ class RowDescriptionTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLBackendMessage.DecodingError) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PSQLDecodingError) } } @@ -104,8 +104,8 @@ class RowDescriptionTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLBackendMessage.DecodingError) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PSQLDecodingError) } } @@ -127,8 +127,8 @@ class RowDescriptionTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLBackendMessage.DecodingError) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + XCTAssert($0 is PSQLDecodingError) } } diff --git a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift index acfef769..049e23d1 100644 --- a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift @@ -97,7 +97,7 @@ class PSQLBackendMessageTests: XCTestCase { expectedMessages.append(.parameterStatus(parameterStatus)) } - let handler = ByteToMessageHandler(PSQLBackendMessage.Decoder()) + let handler = ByteToMessageHandler(PSQLBackendMessageDecoder()) let embedded = EmbeddedChannel(handler: handler) XCTAssertNoThrow(try embedded.writeInbound(buffer)) @@ -137,7 +137,7 @@ class PSQLBackendMessageTests: XCTestCase { buffer.writeInteger(0, as: UInt8.self) // signal done } - let handler = ByteToMessageHandler(PSQLBackendMessage.Decoder()) + let handler = ByteToMessageHandler(PSQLBackendMessageDecoder()) let embedded = EmbeddedChannel(handler: handler) XCTAssertNoThrow(try embedded.writeInbound(buffer)) @@ -174,7 +174,7 @@ class PSQLBackendMessageTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, expected)], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: false) })) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) } func testPayloadsWithoutAssociatedValuesInvalidLength() { @@ -195,8 +195,8 @@ class PSQLBackendMessageTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: false) })) { - XCTAssert($0 is PSQLBackendMessage.DecodingError) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { + XCTAssert($0 is PSQLDecodingError) } } } @@ -222,7 +222,7 @@ class PSQLBackendMessageTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(okBuffer, expected)], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: false) })) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) // test commandTag is not null terminated for message in expected { @@ -237,8 +237,8 @@ class PSQLBackendMessageTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(failBuffer, [])], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: false) })) { - XCTAssert($0 is PSQLBackendMessage.DecodingError) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { + XCTAssert($0 is PSQLDecodingError) } } } @@ -250,8 +250,8 @@ class PSQLBackendMessageTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessage.Decoder(hasAlreadyReceivedBytes: false) })) { - XCTAssert($0 is PSQLBackendMessage.DecodingError) + decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { + XCTAssert($0 is PSQLDecodingError) } } From 7331b52086ca69958b5c2205f1b41df9d751e3e1 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 18 Sep 2021 14:09:19 +0200 Subject: [PATCH 016/292] Add conveniences to PSQLFrontendMessage (#173) ### Motivation We want to use the `NIOSingleStepByteToMessageDecoder` within the `PSQLChannelHandler` in the future. For this reason it is important that we will be able to decode PSQLFrontendMessages for tests in the future. ### Changes - Rename protocol `PSQLFrontendMessagePayloadEncodable` to `PSQLMessagePayloadEncodable` - `PSQLFrontendMessage.ID` is now `RawRepresentable` ### Result Code that makes testing easier in the future. Co-authored-by: Gwynne Raskind --- .../New/Extensions/ByteBuffer+PSQL.swift | 2 +- Sources/PostgresNIO/New/Messages/Cancel.swift | 2 +- Sources/PostgresNIO/New/Messages/Close.swift | 2 +- .../PostgresNIO/New/Messages/Describe.swift | 2 +- .../PostgresNIO/New/Messages/Execute.swift | 2 +- Sources/PostgresNIO/New/Messages/Parse.swift | 2 +- .../PostgresNIO/New/Messages/Password.swift | 2 +- .../New/Messages/SASLInitialResponse.swift | 2 +- .../New/Messages/SASLResponse.swift | 2 +- .../PostgresNIO/New/Messages/SSLRequest.swift | 2 +- .../PostgresNIO/New/Messages/Startup.swift | 2 +- .../PostgresNIO/New/PSQLFrontendMessage.swift | 44 +++++++++++++++---- .../New/Messages/BindTests.swift | 2 +- .../New/Messages/CloseTests.swift | 4 +- .../New/Messages/DescribeTests.swift | 4 +- .../New/Messages/ExecuteTests.swift | 2 +- .../New/Messages/ParseTests.swift | 2 +- .../New/Messages/PasswordTests.swift | 2 +- .../Messages/SASLInitialResponseTests.swift | 4 +- .../New/Messages/SASLResponseTests.swift | 4 +- .../New/PSQLFrontendMessageTests.swift | 28 ++++++------ 21 files changed, 73 insertions(+), 45 deletions(-) diff --git a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift index 3245b168..45197cc0 100644 --- a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift +++ b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift @@ -20,7 +20,7 @@ internal extension ByteBuffer { } mutating func writeFrontendMessageID(_ messageID: PSQLFrontendMessage.ID) { - self.writeInteger(messageID.byte) + self.writeInteger(messageID.rawValue) } mutating func readFloat() -> Float? { diff --git a/Sources/PostgresNIO/New/Messages/Cancel.swift b/Sources/PostgresNIO/New/Messages/Cancel.swift index 7983e0b3..d2756580 100644 --- a/Sources/PostgresNIO/New/Messages/Cancel.swift +++ b/Sources/PostgresNIO/New/Messages/Cancel.swift @@ -2,7 +2,7 @@ import NIOCore extension PSQLFrontendMessage { - struct Cancel: PayloadEncodable, Equatable { + struct Cancel: PSQLMessagePayloadEncodable, Equatable { /// The cancel request code. The value is chosen to contain 1234 in the most significant 16 bits, /// and 5678 in the least significant 16 bits. (To avoid confusion, this code must not be the same /// as any protocol version number.) diff --git a/Sources/PostgresNIO/New/Messages/Close.swift b/Sources/PostgresNIO/New/Messages/Close.swift index fa755dc3..5ed532e6 100644 --- a/Sources/PostgresNIO/New/Messages/Close.swift +++ b/Sources/PostgresNIO/New/Messages/Close.swift @@ -2,7 +2,7 @@ import NIOCore extension PSQLFrontendMessage { - enum Close: PayloadEncodable, Equatable { + enum Close: PSQLMessagePayloadEncodable, Equatable { case preparedStatement(String) case portal(String) diff --git a/Sources/PostgresNIO/New/Messages/Describe.swift b/Sources/PostgresNIO/New/Messages/Describe.swift index 76ba56e1..0a3105cc 100644 --- a/Sources/PostgresNIO/New/Messages/Describe.swift +++ b/Sources/PostgresNIO/New/Messages/Describe.swift @@ -2,7 +2,7 @@ import NIOCore extension PSQLFrontendMessage { - enum Describe: PayloadEncodable, Equatable { + enum Describe: PSQLMessagePayloadEncodable, Equatable { case preparedStatement(String) case portal(String) diff --git a/Sources/PostgresNIO/New/Messages/Execute.swift b/Sources/PostgresNIO/New/Messages/Execute.swift index f88e4482..891bd9aa 100644 --- a/Sources/PostgresNIO/New/Messages/Execute.swift +++ b/Sources/PostgresNIO/New/Messages/Execute.swift @@ -2,7 +2,7 @@ import NIOCore extension PSQLFrontendMessage { - struct Execute: PayloadEncodable, Equatable { + struct Execute: PSQLMessagePayloadEncodable, Equatable { /// The name of the portal to execute (an empty string selects the unnamed portal). let portalName: String diff --git a/Sources/PostgresNIO/New/Messages/Parse.swift b/Sources/PostgresNIO/New/Messages/Parse.swift index 1f45115e..1d0aec19 100644 --- a/Sources/PostgresNIO/New/Messages/Parse.swift +++ b/Sources/PostgresNIO/New/Messages/Parse.swift @@ -2,7 +2,7 @@ import NIOCore extension PSQLFrontendMessage { - struct Parse: PayloadEncodable, Equatable { + struct Parse: PSQLMessagePayloadEncodable, Equatable { /// The name of the destination prepared statement (an empty string selects the unnamed prepared statement). let preparedStatementName: String diff --git a/Sources/PostgresNIO/New/Messages/Password.swift b/Sources/PostgresNIO/New/Messages/Password.swift index e8942561..88e885f9 100644 --- a/Sources/PostgresNIO/New/Messages/Password.swift +++ b/Sources/PostgresNIO/New/Messages/Password.swift @@ -2,7 +2,7 @@ import NIOCore extension PSQLFrontendMessage { - struct Password: PayloadEncodable, Equatable { + struct Password: PSQLMessagePayloadEncodable, Equatable { let value: String func encode(into buffer: inout ByteBuffer) { diff --git a/Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift b/Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift index 916e99ce..ead609c7 100644 --- a/Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift +++ b/Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift @@ -2,7 +2,7 @@ import NIOCore extension PSQLFrontendMessage { - struct SASLInitialResponse: PayloadEncodable, Equatable { + struct SASLInitialResponse: PSQLMessagePayloadEncodable, Equatable { let saslMechanism: String let initialData: [UInt8] diff --git a/Sources/PostgresNIO/New/Messages/SASLResponse.swift b/Sources/PostgresNIO/New/Messages/SASLResponse.swift index 8785edcd..dc49a506 100644 --- a/Sources/PostgresNIO/New/Messages/SASLResponse.swift +++ b/Sources/PostgresNIO/New/Messages/SASLResponse.swift @@ -2,7 +2,7 @@ import NIOCore extension PSQLFrontendMessage { - struct SASLResponse: PayloadEncodable, Equatable { + struct SASLResponse: PSQLMessagePayloadEncodable, Equatable { let data: [UInt8] diff --git a/Sources/PostgresNIO/New/Messages/SSLRequest.swift b/Sources/PostgresNIO/New/Messages/SSLRequest.swift index 8995804a..f67f25fe 100644 --- a/Sources/PostgresNIO/New/Messages/SSLRequest.swift +++ b/Sources/PostgresNIO/New/Messages/SSLRequest.swift @@ -3,7 +3,7 @@ import NIOCore extension PSQLFrontendMessage { /// A message asking the PostgreSQL server if TLS is supported /// For more info, see https://www.postgresql.org/docs/10/static/protocol-flow.html#id-1.10.5.7.11 - struct SSLRequest: PayloadEncodable, Equatable { + struct SSLRequest: PSQLMessagePayloadEncodable, Equatable { /// The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits, /// and 5679 in the least significant 16 bits. let code: Int32 diff --git a/Sources/PostgresNIO/New/Messages/Startup.swift b/Sources/PostgresNIO/New/Messages/Startup.swift index 7e6e7db7..148b8bc2 100644 --- a/Sources/PostgresNIO/New/Messages/Startup.swift +++ b/Sources/PostgresNIO/New/Messages/Startup.swift @@ -1,7 +1,7 @@ import NIOCore extension PSQLFrontendMessage { - struct Startup: PayloadEncodable, Equatable { + struct Startup: PSQLMessagePayloadEncodable, Equatable { /// Creates a `Startup` with "3.0" as the protocol version. static func versionThree(parameters: Parameters) -> Startup { diff --git a/Sources/PostgresNIO/New/PSQLFrontendMessage.swift b/Sources/PostgresNIO/New/PSQLFrontendMessage.swift index ee7196c8..5800b2da 100644 --- a/Sources/PostgresNIO/New/PSQLFrontendMessage.swift +++ b/Sources/PostgresNIO/New/PSQLFrontendMessage.swift @@ -5,8 +5,6 @@ import NIOCore /// All messages are defined in the official Postgres Documentation in the section /// [Frontend/Backend Protocol – Message Formats](https://www.postgresql.org/docs/13/protocol-message-formats.html) enum PSQLFrontendMessage { - typealias PayloadEncodable = PSQLFrontendMessagePayloadEncodable - case bind(Bind) case cancel(Cancel) case close(Close) @@ -22,7 +20,8 @@ enum PSQLFrontendMessage { case startup(Startup) case terminate - enum ID { + enum ID: UInt8, Equatable { + case bind case close case describe @@ -35,7 +34,36 @@ enum PSQLFrontendMessage { case sync case terminate - var byte: UInt8 { + init?(rawValue: UInt8) { + switch rawValue { + case UInt8(ascii: "B"): + self = .bind + case UInt8(ascii: "C"): + self = .close + case UInt8(ascii: "D"): + self = .describe + case UInt8(ascii: "E"): + self = .execute + case UInt8(ascii: "H"): + self = .flush + case UInt8(ascii: "P"): + self = .parse + case UInt8(ascii: "p"): + self = .password + case UInt8(ascii: "p"): + self = .saslInitialResponse + case UInt8(ascii: "p"): + self = .saslResponse + case UInt8(ascii: "S"): + self = .sync + case UInt8(ascii: "X"): + self = .terminate + default: + return nil + } + } + + var rawValue: UInt8 { switch self { case .bind: return UInt8(ascii: "B") @@ -112,11 +140,11 @@ extension PSQLFrontendMessage { } func encode(data message: PSQLFrontendMessage, out buffer: inout ByteBuffer) throws { - struct EmptyPayload: PayloadEncodable { + struct EmptyPayload: PSQLMessagePayloadEncodable { func encode(into buffer: inout ByteBuffer) {} } - func encode(_ payload: Payload, into buffer: inout ByteBuffer) { + func encode(_ payload: Payload, into buffer: inout ByteBuffer) { let startIndex = buffer.writerIndex buffer.writeInteger(Int32(0)) // placeholder for length payload.encode(into: &buffer) @@ -126,7 +154,7 @@ extension PSQLFrontendMessage { switch message { case .bind(let bind): - buffer.writeInteger(message.id.byte) + buffer.writeInteger(message.id.rawValue) let startIndex = buffer.writerIndex buffer.writeInteger(Int32(0)) // placeholder for length try bind.encode(into: &buffer, using: self.jsonEncoder) @@ -177,6 +205,6 @@ extension PSQLFrontendMessage { } } -protocol PSQLFrontendMessagePayloadEncodable { +protocol PSQLMessagePayloadEncodable { func encode(into buffer: inout ByteBuffer) } diff --git a/Tests/PostgresNIOTests/New/Messages/BindTests.swift b/Tests/PostgresNIOTests/New/Messages/BindTests.swift index 43e7e7cf..1f300342 100644 --- a/Tests/PostgresNIOTests/New/Messages/BindTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BindTests.swift @@ -12,7 +12,7 @@ class BindTests: XCTestCase { XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) XCTAssertEqual(byteBuffer.readableBytes, 37) - XCTAssertEqual(PSQLFrontendMessage.ID.bind.byte, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(PSQLFrontendMessage.ID.bind.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 36) XCTAssertEqual("", byteBuffer.readNullTerminatedString()) XCTAssertEqual("", byteBuffer.readNullTerminatedString()) diff --git a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift index c75fe78a..b7734ebf 100644 --- a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift @@ -11,7 +11,7 @@ class CloseTests: XCTestCase { XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) XCTAssertEqual(byteBuffer.readableBytes, 12) - XCTAssertEqual(PSQLFrontendMessage.ID.close.byte, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(PSQLFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(11, byteBuffer.readInteger(as: Int32.self)) XCTAssertEqual(UInt8(ascii: "P"), byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual("Hello", byteBuffer.readNullTerminatedString()) @@ -25,7 +25,7 @@ class CloseTests: XCTestCase { XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) XCTAssertEqual(byteBuffer.readableBytes, 7) - XCTAssertEqual(PSQLFrontendMessage.ID.close.byte, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(PSQLFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(6, byteBuffer.readInteger(as: Int32.self)) XCTAssertEqual(UInt8(ascii: "S"), byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual("", byteBuffer.readNullTerminatedString()) diff --git a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift index 777e0769..5ce6d2f0 100644 --- a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift @@ -11,7 +11,7 @@ class DescribeTests: XCTestCase { XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) XCTAssertEqual(byteBuffer.readableBytes, 12) - XCTAssertEqual(PSQLFrontendMessage.ID.describe.byte, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(PSQLFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(11, byteBuffer.readInteger(as: Int32.self)) XCTAssertEqual(UInt8(ascii: "P"), byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual("Hello", byteBuffer.readNullTerminatedString()) @@ -25,7 +25,7 @@ class DescribeTests: XCTestCase { XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) XCTAssertEqual(byteBuffer.readableBytes, 7) - XCTAssertEqual(PSQLFrontendMessage.ID.describe.byte, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(PSQLFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(6, byteBuffer.readInteger(as: Int32.self)) XCTAssertEqual(UInt8(ascii: "S"), byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual("", byteBuffer.readNullTerminatedString()) diff --git a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift index 177e5fcb..01093060 100644 --- a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift @@ -11,7 +11,7 @@ class ExecuteTests: XCTestCase { XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) XCTAssertEqual(byteBuffer.readableBytes, 10) // 1 (id) + 4 (length) + 1 (empty null terminated string) + 4 (count) - XCTAssertEqual(PSQLFrontendMessage.ID.execute.byte, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(PSQLFrontendMessage.ID.execute.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(9, byteBuffer.readInteger(as: Int32.self)) // length XCTAssertEqual("", byteBuffer.readNullTerminatedString()) XCTAssertEqual(0, byteBuffer.readInteger(as: Int32.self)) diff --git a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift index 146fc57f..1239633d 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift @@ -22,7 +22,7 @@ class ParseTests: XCTestCase { // + 1 query () XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.parse.byte) + XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.parse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) XCTAssertEqual(byteBuffer.readNullTerminatedString(), parse.preparedStatementName) XCTAssertEqual(byteBuffer.readNullTerminatedString(), parse.query) diff --git a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift index 75ab3a85..0e8e2920 100644 --- a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift @@ -14,7 +14,7 @@ class PasswordTests: XCTestCase { let expectedLength = 41 // 1 (id) + 4 (length) + 35 (string) + 1 (null termination) XCTAssertEqual(byteBuffer.readableBytes, expectedLength) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.password.byte) + XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.password.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(expectedLength - 1)) // length XCTAssertEqual(byteBuffer.readNullTerminatedString(), "md522d085ed8dc3377968dc1c1a40519a2a") } diff --git a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift index c846c260..00a601e4 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift @@ -21,7 +21,7 @@ class SASLInitialResponseTests: XCTestCase { // + 8 initialData XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.saslInitialResponse.byte) + XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.saslInitialResponse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) XCTAssertEqual(byteBuffer.readNullTerminatedString(), sasl.saslMechanism) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(sasl.initialData.count)) @@ -46,7 +46,7 @@ class SASLInitialResponseTests: XCTestCase { // + 0 initialData XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.saslInitialResponse.byte) + XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.saslInitialResponse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) XCTAssertEqual(byteBuffer.readNullTerminatedString(), sasl.saslMechanism) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(-1)) diff --git a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift index e4556ac2..6f117105 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift @@ -14,7 +14,7 @@ class SASLResponseTests: XCTestCase { let length: Int = 1 + 4 + (sasl.data.count) XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.saslResponse.byte) + XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.saslResponse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) XCTAssertEqual(byteBuffer.readBytes(length: sasl.data.count), sasl.data) XCTAssertEqual(byteBuffer.readableBytes, 0) @@ -30,7 +30,7 @@ class SASLResponseTests: XCTestCase { let length: Int = 1 + 4 XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.saslResponse.byte) + XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.saslResponse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) XCTAssertEqual(byteBuffer.readableBytes, 0) } diff --git a/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift index 69fa1374..55c4c4d1 100644 --- a/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift @@ -7,17 +7,17 @@ class PSQLFrontendMessageTests: XCTestCase { // MARK: ID func testMessageIDs() { - XCTAssertEqual(PSQLFrontendMessage.ID.bind.byte, UInt8(ascii: "B")) - XCTAssertEqual(PSQLFrontendMessage.ID.close.byte, UInt8(ascii: "C")) - XCTAssertEqual(PSQLFrontendMessage.ID.describe.byte, UInt8(ascii: "D")) - XCTAssertEqual(PSQLFrontendMessage.ID.execute.byte, UInt8(ascii: "E")) - XCTAssertEqual(PSQLFrontendMessage.ID.flush.byte, UInt8(ascii: "H")) - XCTAssertEqual(PSQLFrontendMessage.ID.parse.byte, UInt8(ascii: "P")) - XCTAssertEqual(PSQLFrontendMessage.ID.password.byte, UInt8(ascii: "p")) - XCTAssertEqual(PSQLFrontendMessage.ID.saslInitialResponse.byte, UInt8(ascii: "p")) - XCTAssertEqual(PSQLFrontendMessage.ID.saslResponse.byte, UInt8(ascii: "p")) - XCTAssertEqual(PSQLFrontendMessage.ID.sync.byte, UInt8(ascii: "S")) - XCTAssertEqual(PSQLFrontendMessage.ID.terminate.byte, UInt8(ascii: "X")) + XCTAssertEqual(PSQLFrontendMessage.ID.bind.rawValue, UInt8(ascii: "B")) + XCTAssertEqual(PSQLFrontendMessage.ID.close.rawValue, UInt8(ascii: "C")) + XCTAssertEqual(PSQLFrontendMessage.ID.describe.rawValue, UInt8(ascii: "D")) + XCTAssertEqual(PSQLFrontendMessage.ID.execute.rawValue, UInt8(ascii: "E")) + XCTAssertEqual(PSQLFrontendMessage.ID.flush.rawValue, UInt8(ascii: "H")) + XCTAssertEqual(PSQLFrontendMessage.ID.parse.rawValue, UInt8(ascii: "P")) + XCTAssertEqual(PSQLFrontendMessage.ID.password.rawValue, UInt8(ascii: "p")) + XCTAssertEqual(PSQLFrontendMessage.ID.saslInitialResponse.rawValue, UInt8(ascii: "p")) + XCTAssertEqual(PSQLFrontendMessage.ID.saslResponse.rawValue, UInt8(ascii: "p")) + XCTAssertEqual(PSQLFrontendMessage.ID.sync.rawValue, UInt8(ascii: "S")) + XCTAssertEqual(PSQLFrontendMessage.ID.terminate.rawValue, UInt8(ascii: "X")) } // MARK: Encoder @@ -28,7 +28,7 @@ class PSQLFrontendMessageTests: XCTestCase { XCTAssertNoThrow(try encoder.encode(data: .flush, out: &byteBuffer)) XCTAssertEqual(byteBuffer.readableBytes, 5) - XCTAssertEqual(PSQLFrontendMessage.ID.flush.byte, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(PSQLFrontendMessage.ID.flush.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(4, byteBuffer.readInteger(as: Int32.self)) // payload length } @@ -38,7 +38,7 @@ class PSQLFrontendMessageTests: XCTestCase { XCTAssertNoThrow(try encoder.encode(data: .sync, out: &byteBuffer)) XCTAssertEqual(byteBuffer.readableBytes, 5) - XCTAssertEqual(PSQLFrontendMessage.ID.sync.byte, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(PSQLFrontendMessage.ID.sync.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(4, byteBuffer.readInteger(as: Int32.self)) // payload length } @@ -48,7 +48,7 @@ class PSQLFrontendMessageTests: XCTestCase { XCTAssertNoThrow(try encoder.encode(data: .terminate, out: &byteBuffer)) XCTAssertEqual(byteBuffer.readableBytes, 5) - XCTAssertEqual(PSQLFrontendMessage.ID.terminate.byte, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(PSQLFrontendMessage.ID.terminate.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(4, byteBuffer.readInteger(as: Int32.self)) // payload length } From ce57b02b76779aa99b61594acd99f519100b32bd Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 18 Sep 2021 15:53:07 +0200 Subject: [PATCH 017/292] Add RowStreamStateMachine (#176) ### Motivation To best support back-pressure, we extract the necessary state machine into a dedicated `RowStreamStateMachine`. ### Changes - Add RowStreamStateMachine --- .../RowStreamStateMachine.swift | 163 ++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift diff --git a/Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift new file mode 100644 index 00000000..165ba4f3 --- /dev/null +++ b/Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift @@ -0,0 +1,163 @@ +import NIOCore + +/// A sub state for receiving data rows. Stores whether the consumer has either signaled demand and whether the +/// channel has issued `read` events. +/// +/// This should be used as a SubStateMachine in QuerySubStateMachines. +struct RowStreamStateMachine { + + enum Action { + case read + case wait + } + + private enum State { + /// The state machines expects further writes to `channelRead`. The writes are appended to the buffer. + case waitingForRows(CircularBuffer) + /// The state machines expects a call to `demandMoreResponseBodyParts` or `read`. The buffer is + /// empty. It is preserved for performance reasons. + case waitingForReadOrDemand(CircularBuffer) + /// The state machines expects a call to `read`. The buffer is empty. It is preserved for performance reasons. + case waitingForRead(CircularBuffer) + /// The state machines expects a call to `demandMoreResponseBodyParts`. The buffer is empty. It is + /// preserved for performance reasons. + case waitingForDemand(CircularBuffer) + + case modifying + } + + private var state: State + + init() { + self.state = .waitingForRows(CircularBuffer(initialCapacity: 32)) + } + + mutating func receivedRow(_ newRow: PSQLBackendMessage.DataRow) { + switch self.state { + case .waitingForRows(var buffer): + self.state = .modifying + buffer.append(newRow) + self.state = .waitingForRows(buffer) + + // For all the following cases, please note: + // Normally these code paths should never be hit. However there is one way to trigger + // this: + // + // If the server decides to close a connection, NIO will forward all outstanding + // `channelRead`s without waiting for a next `context.read` call. For this reason we might + // receive new rows, when we don't expect them here. + case .waitingForRead(var buffer): + self.state = .modifying + buffer.append(newRow) + self.state = .waitingForRead(buffer) + + case .waitingForDemand(var buffer): + self.state = .modifying + buffer.append(newRow) + self.state = .waitingForDemand(buffer) + + case .waitingForReadOrDemand(var buffer): + self.state = .modifying + buffer.append(newRow) + self.state = .waitingForReadOrDemand(buffer) + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + mutating func channelReadComplete() -> CircularBuffer? { + switch self.state { + case .waitingForRows(let buffer): + if buffer.isEmpty { + self.state = .waitingForRead(buffer) + return nil + } else { + var newBuffer = buffer + newBuffer.removeAll(keepingCapacity: true) + self.state = .waitingForReadOrDemand(newBuffer) + return buffer + } + + case .waitingForRead, + .waitingForDemand, + .waitingForReadOrDemand: + preconditionFailure("How can we receive a body part, after a channelReadComplete, but no read has been forwarded yet. Invalid state: \(self.state)") + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + mutating func demandMoreResponseBodyParts() -> Action { + switch self.state { + case .waitingForDemand(let buffer): + self.state = .waitingForRows(buffer) + return .read + + case .waitingForReadOrDemand(let buffer): + self.state = .waitingForRead(buffer) + return .wait + + case .waitingForRead: + // If we are `.waitingForRead`, no action needs to be taken. Demand has already been + // signaled. Once we receive the next `read`, we will forward it, right away + return .wait + + case .waitingForRows: + // If we are `.waitingForRows`, no action needs to be taken. As soon as we receive + // the next `channelReadComplete` we will forward all buffered data + return .wait + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + mutating func read() -> Action { + switch self.state { + case .waitingForRows: + // This should never happen. But we don't want to precondition this behavior. Let's just + // pass the read event on + return .read + + case .waitingForReadOrDemand(let buffer): + self.state = .waitingForDemand(buffer) + return .wait + + case .waitingForRead(let buffer): + self.state = .waitingForRows(buffer) + return .read + + case .waitingForDemand: + // we have already received a read event. We will issue it as soon as we received demand + // from the consumer + return .wait + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + mutating func end() -> CircularBuffer { + switch self.state { + case .waitingForRows(let buffer): + return buffer + + case .waitingForReadOrDemand(let buffer), + .waitingForRead(let buffer), + .waitingForDemand(let buffer): + + // Normally this code path should never be hit. However there is one way to trigger + // this: + // + // If the server decides to close a connection, NIO will forward all outstanding + // `channelRead`s without waiting for a next `context.read` call. For this reason we might + // receive a call to `end()`, when we don't expect it here. + return buffer + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } +} From 8b13752d8d6d7a02e1cb623e07b42cd1398bfd57 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 18 Sep 2021 16:54:35 +0200 Subject: [PATCH 018/292] Add PSQLBackendMessageEncoder (#175) ### Motivation To test a `PSQLChannelHandler`, that uses an internal `NIOSingleStepByteToMessageDecoder`, in an `EmbeddedChannel` we need to writeInbound bytes. To make this easier, this PR introduces a `PSQLBackendMessageEncoder` as a test util ### Changes - Add `PSQLBackendMessageEncoder` - Use `PSQLBackendMessageEncoder` in authentication tests --- .../PSQLBackendMessageEncoder.swift | 272 ++++++++++++++++++ .../New/Messages/AuthenticationTests.swift | 34 +-- 2 files changed, 280 insertions(+), 26 deletions(-) create mode 100644 Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift new file mode 100644 index 00000000..ea5323ec --- /dev/null +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift @@ -0,0 +1,272 @@ +import NIOCore +@testable import PostgresNIO + +struct PSQLBackendMessageEncoder: MessageToByteEncoder { + typealias OutboundIn = PSQLBackendMessage + + /// Called once there is data to encode. + /// + /// - parameters: + /// - data: The data to encode into a `ByteBuffer`. + /// - out: The `ByteBuffer` into which we want to encode. + func encode(data message: PSQLBackendMessage, out buffer: inout ByteBuffer) throws { + switch message { + case .authentication(let authentication): + self.encode(messageID: message.id, payload: authentication, into: &buffer) + + case .backendKeyData(let keyData): + self.encode(messageID: message.id, payload: keyData, into: &buffer) + + case .bindComplete, + .closeComplete, + .emptyQueryResponse, + .noData, + .parseComplete, + .portalSuspended: + self.encode(messageID: message.id, payload: EmptyPayload(), into: &buffer) + + case .commandComplete(let string): + self.encode(messageID: message.id, payload: StringPayload(string), into: &buffer) + + case .dataRow(let row): + self.encode(messageID: message.id, payload: row, into: &buffer) + + case .error(let errorResponse): + self.encode(messageID: message.id, payload: errorResponse, into: &buffer) + + case .notice(let noticeResponse): + self.encode(messageID: message.id, payload: noticeResponse, into: &buffer) + + case .notification(let notificationResponse): + self.encode(messageID: message.id, payload: notificationResponse, into: &buffer) + + case .parameterDescription(let description): + self.encode(messageID: message.id, payload: description, into: &buffer) + + case .parameterStatus(let status): + self.encode(messageID: message.id, payload: status, into: &buffer) + + case .readyForQuery(let transactionState): + self.encode(messageID: message.id, payload: transactionState, into: &buffer) + + case .rowDescription(let description): + self.encode(messageID: message.id, payload: description, into: &buffer) + + case .sslSupported: + buffer.writeInteger(UInt8(ascii: "S")) + + case .sslUnsupported: + buffer.writeInteger(UInt8(ascii: "N")) + } + } + + private struct EmptyPayload: PSQLMessagePayloadEncodable { + func encode(into buffer: inout ByteBuffer) {} + } + + private struct StringPayload: PSQLMessagePayloadEncodable { + var string: String + init(_ string: String) { self.string = string } + func encode(into buffer: inout ByteBuffer) { + buffer.writeNullTerminatedString(self.string) + } + } + + private func encode( + messageID: PSQLBackendMessage.ID, + payload: Payload, + into buffer: inout ByteBuffer) + { + buffer.writeBackendMessageID(messageID) + let startIndex = buffer.writerIndex + buffer.writeInteger(Int32(0)) // placeholder for length + payload.encode(into: &buffer) + let length = Int32(buffer.writerIndex - startIndex) + buffer.setInteger(length, at: startIndex) + } +} + +extension PSQLBackendMessage { + var id: ID { + switch self { + case .authentication: + return .authentication + case .backendKeyData: + return .backendKeyData + case .bindComplete: + return .bindComplete + case .closeComplete: + return .closeComplete + case .commandComplete: + return .commandComplete + case .dataRow: + return .dataRow + case .emptyQueryResponse: + return .emptyQueryResponse + case .error: + return .error + case .noData: + return .noData + case .notice: + return .noticeResponse + case .notification: + return .notificationResponse + case .parameterDescription: + return .parameterDescription + case .parameterStatus: + return .parameterStatus + case .parseComplete: + return .parseComplete + case .portalSuspended: + return .portalSuspended + case .readyForQuery: + return .readyForQuery + case .rowDescription: + return .rowDescription + case .sslSupported, + .sslUnsupported: + preconditionFailure("Message has no id.") + } + } +} + +extension PSQLBackendMessage.Authentication: PSQLMessagePayloadEncodable { + + public func encode(into buffer: inout ByteBuffer) { + switch self { + case .ok: + buffer.writeInteger(Int32(0)) + + case .kerberosV5: + buffer.writeInteger(Int32(2)) + + case .plaintext: + buffer.writeInteger(Int32(3)) + + case .md5(salt: let salt): + buffer.writeInteger(Int32(5)) + buffer.writeInteger(salt.0) + buffer.writeInteger(salt.1) + buffer.writeInteger(salt.2) + buffer.writeInteger(salt.3) + + case .scmCredential: + buffer.writeInteger(Int32(6)) + + case .gss: + buffer.writeInteger(Int32(7)) + + case .gssContinue(var data): + buffer.writeInteger(Int32(8)) + buffer.writeBuffer(&data) + + case .sspi: + buffer.writeInteger(Int32(9)) + + case .sasl(names: let names): + buffer.writeInteger(Int32(10)) + for name in names { + buffer.writeNullTerminatedString(name) + } + + case .saslContinue(data: var data): + buffer.writeInteger(Int32(11)) + buffer.writeBuffer(&data) + + case .saslFinal(data: var data): + buffer.writeInteger(Int32(12)) + buffer.writeBuffer(&data) + } + } + +} + +extension PSQLBackendMessage.BackendKeyData: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeInteger(self.processID) + buffer.writeInteger(self.secretKey) + } +} + +extension PSQLBackendMessage.DataRow: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeInteger(Int16(self.columns.count)) + + for column in self.columns { + switch column { + case .none: + buffer.writeInteger(-1, as: Int32.self) + case .some(var writable): + buffer.writeInteger(Int32(writable.readableBytes)) + buffer.writeBuffer(&writable) + } + } + } +} + +extension PSQLBackendMessage.ErrorResponse: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + for (key, value) in self.fields { + buffer.writeInteger(key.rawValue, as: UInt8.self) + buffer.writeNullTerminatedString(value) + } + buffer.writeInteger(0, as: UInt8.self) // signal done + } +} + +extension PSQLBackendMessage.NoticeResponse: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + for (key, value) in self.fields { + buffer.writeInteger(key.rawValue, as: UInt8.self) + buffer.writeNullTerminatedString(value) + } + buffer.writeInteger(0, as: UInt8.self) // signal done + } +} + +extension PSQLBackendMessage.NotificationResponse: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeInteger(self.backendPID) + buffer.writeNullTerminatedString(self.channel) + buffer.writeNullTerminatedString(self.payload) + } +} + +extension PSQLBackendMessage.ParameterDescription: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeInteger(Int16(self.dataTypes.count)) + + for dataType in self.dataTypes { + buffer.writeInteger(dataType.rawValue) + } + } +} + +extension PSQLBackendMessage.ParameterStatus: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeNullTerminatedString(self.parameter) + buffer.writeNullTerminatedString(self.value) + } +} + +extension PSQLBackendMessage.TransactionState: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeInteger(self.rawValue) + } +} + +extension PSQLBackendMessage.RowDescription: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeInteger(Int16(self.columns.count)) + + for column in self.columns { + buffer.writeNullTerminatedString(column.name) + buffer.writeInteger(column.tableOID) + buffer.writeInteger(column.columnAttributeNumber) + buffer.writeInteger(column.dataType.rawValue) + buffer.writeInteger(column.dataTypeSize) + buffer.writeInteger(column.dataTypeModifier) + buffer.writeInteger(column.format.rawValue) + } + } +} diff --git a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift index 9091b3cf..52e63b2e 100644 --- a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift @@ -8,54 +8,36 @@ class AuthenticationTests: XCTestCase { func testDecodeAuthentication() { var expected = [PSQLBackendMessage]() var buffer = ByteBuffer() + let encoder = PSQLBackendMessageEncoder() // add ok - buffer.writeBackendMessage(id: .authentication) { buffer in - buffer.writeInteger(Int32(0)) - } + XCTAssertNoThrow(try encoder.encode(data: .authentication(.ok), out: &buffer)) expected.append(.authentication(.ok)) // add kerberos - buffer.writeBackendMessage(id: .authentication) { buffer in - buffer.writeInteger(Int32(2)) - } + XCTAssertNoThrow(try encoder.encode(data: .authentication(.kerberosV5), out: &buffer)) expected.append(.authentication(.kerberosV5)) // add plaintext - buffer.writeBackendMessage(id: .authentication) { buffer in - buffer.writeInteger(Int32(3)) - } + XCTAssertNoThrow(try encoder.encode(data: .authentication(.plaintext), out: &buffer)) expected.append(.authentication(.plaintext)) // add md5 - buffer.writeBackendMessage(id: .authentication) { buffer in - buffer.writeInteger(Int32(5)) - buffer.writeInteger(UInt8(1)) - buffer.writeInteger(UInt8(2)) - buffer.writeInteger(UInt8(3)) - buffer.writeInteger(UInt8(4)) - } + XCTAssertNoThrow(try encoder.encode(data: .authentication(.md5(salt: (1, 2, 3, 4))), out: &buffer)) expected.append(.authentication(.md5(salt: (1, 2, 3, 4)))) // add scm credential - buffer.writeBackendMessage(id: .authentication) { buffer in - buffer.writeInteger(Int32(6)) - } + XCTAssertNoThrow(try encoder.encode(data: .authentication(.scmCredential), out: &buffer)) expected.append(.authentication(.scmCredential)) // add gss - buffer.writeBackendMessage(id: .authentication) { buffer in - buffer.writeInteger(Int32(7)) - } + XCTAssertNoThrow(try encoder.encode(data: .authentication(.gss), out: &buffer)) expected.append(.authentication(.gss)) // add sspi - buffer.writeBackendMessage(id: .authentication) { buffer in - buffer.writeInteger(Int32(9)) - } + XCTAssertNoThrow(try encoder.encode(data: .authentication(.sspi), out: &buffer)) expected.append(.authentication(.sspi)) - XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, expected)], decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) From 3c29758edfe2fe4b3cb69f344b01defc04f88b3e Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 18 Sep 2021 16:56:17 +0200 Subject: [PATCH 019/292] Extract PSQLRow (#177) ### Motivation In #135, I did some naming things just wrong. `PSQLRows` is a stupid name. We should name it `PSQLRowStream`. `PSQLRows.Row` is a stupid name for a single table row. We should name it `PSQLRow`. Transforming an incoming data row packet to a `[PSQLData]` array early is expensive and stupid. Let's not do this anymore. ### Changes - Extract `PSQLRows.Row` to `PSQLRows` (Got its own file). Stop the early `[PSQLData]` madness. - Rename `PSQLRows` to `PSQLRowStream` - Fix naming in integration tests to match `PSQLRowStream` --- .../PostgresConnection+Database.swift | 6 +- .../ConnectionStateMachine.swift | 10 +- .../ExtendedQueryStateMachine.swift | 21 ++-- .../PostgresNIO/New/PSQLChannelHandler.swift | 8 +- Sources/PostgresNIO/New/PSQLConnection.swift | 10 +- Sources/PostgresNIO/New/PSQLRow.swift | 66 ++++++++++ .../{PSQLRows.swift => PSQLRowStream.swift} | 66 ++-------- Sources/PostgresNIO/New/PSQLTask.swift | 6 +- .../PSQLIntegrationTests.swift | 116 +++++++++--------- .../ConnectionStateMachineTests.swift | 2 +- .../ExtendedQueryStateMachineTests.swift | 8 +- 11 files changed, 165 insertions(+), 154 deletions(-) create mode 100644 Sources/PostgresNIO/New/PSQLRow.swift rename Sources/PostgresNIO/New/{PSQLRows.swift => PSQLRowStream.swift} (71%) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift index 6ee6ddbf..e48ac9ff 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift @@ -78,12 +78,12 @@ internal enum PostgresCommands: PostgresRequest { } } -extension PSQLRows { +extension PSQLRowStream { func iterateRowsWithoutBackpressureOption(lookupTable: PostgresRow.LookupTable, onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { self.onRow { psqlRow in - let columns = psqlRow.data.map { psqlData in - PostgresMessage.DataRow.Column(value: psqlData.bytes) + let columns = psqlRow.data.columns.map { bytes in + PostgresMessage.DataRow.Column(value: bytes) } let row = PostgresRow(dataRow: .init(columns: columns), lookupTable: lookupTable) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 49168e97..dbeafa5d 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -92,12 +92,12 @@ struct ConnectionStateMachine { // --- streaming actions // actions if query has requested next row but we are waiting for backend - case forwardRow([PSQLData], to: EventLoopPromise) - case forwardCommandComplete(CircularBuffer<[PSQLData]>, commandTag: String, to: EventLoopPromise) + case forwardRow(PSQLBackendMessage.DataRow, to: EventLoopPromise) + case forwardCommandComplete(CircularBuffer, commandTag: String, to: EventLoopPromise) case forwardStreamError(PSQLError, to: EventLoopPromise, cleanupContext: CleanUpContext?) // actions if query has not asked for next row but are pushing the final bytes to it case forwardStreamErrorToCurrentQuery(PSQLError, read: Bool, cleanupContext: CleanUpContext?) - case forwardStreamCompletedToCurrentQuery(CircularBuffer<[PSQLData]>, commandTag: String, read: Bool) + case forwardStreamCompletedToCurrentQuery(CircularBuffer, commandTag: String, read: Bool) // Prepare statement actions case sendParseDescribeSync(name: String, query: String) @@ -1106,10 +1106,10 @@ extension ConnectionStateMachine { enum StateMachineStreamNextResult { /// the next row - case row([PSQLData]) + case row(PSQLBackendMessage.DataRow) /// the query has completed, all remaining rows and the command completion tag - case complete(CircularBuffer<[PSQLData]>, commandTag: String) + case complete(CircularBuffer, commandTag: String) } struct SendPrepareStatement { diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index f1ae086f..36b69f83 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -14,8 +14,8 @@ struct ExtendedQueryStateMachine { /// A state that is used if a noData message was received before. If a row description was received `bufferingRows` is /// used after receiving a `bindComplete` message case bindCompleteReceived(ExtendedQueryContext) - case bufferingRows([PSQLBackendMessage.RowDescription.Column], CircularBuffer<[PSQLData]>, readOnEmpty: Bool) - case waitingForNextRow([PSQLBackendMessage.RowDescription.Column], CircularBuffer<[PSQLData]>, EventLoopPromise) + case bufferingRows([PSQLBackendMessage.RowDescription.Column], CircularBuffer, readOnEmpty: Bool) + case waitingForNextRow([PSQLBackendMessage.RowDescription.Column], CircularBuffer, EventLoopPromise) case commandComplete(commandTag: String) case error(PSQLError) @@ -34,12 +34,12 @@ struct ExtendedQueryStateMachine { // --- streaming actions // actions if query has requested next row but we are waiting for backend - case forwardRow([PSQLData], to: EventLoopPromise) - case forwardCommandComplete(CircularBuffer<[PSQLData]>, commandTag: String, to: EventLoopPromise) + case forwardRow(PSQLBackendMessage.DataRow, to: EventLoopPromise) + case forwardCommandComplete(CircularBuffer, commandTag: String, to: EventLoopPromise) case forwardStreamError(PSQLError, to: EventLoopPromise) // actions if query has not asked for next row but are pushing the final bytes to it case forwardStreamErrorToCurrentQuery(PSQLError, read: Bool) - case forwardStreamCompletedToCurrentQuery(CircularBuffer<[PSQLData]>, commandTag: String, read: Bool) + case forwardStreamCompletedToCurrentQuery(CircularBuffer, commandTag: String, read: Bool) case read case wait @@ -170,10 +170,7 @@ struct ExtendedQueryStateMachine { } return self.avoidingStateMachineCoW { state -> Action in - let row = dataRow.columns.enumerated().map { (index, buffer) in - PSQLData(bytes: buffer, dataType: columns[index].dataType, format: columns[index].format) - } - buffer.append(row) + buffer.append(dataRow) state = .bufferingRows(columns, buffer, readOnEmpty: readOnEmpty) return .wait } @@ -187,12 +184,8 @@ struct ExtendedQueryStateMachine { return self.avoidingStateMachineCoW { state -> Action in precondition(buffer.isEmpty, "Expected the buffer to be empty") - let row = dataRow.columns.enumerated().map { (index, buffer) in - PSQLData(bytes: buffer, dataType: columns[index].dataType, format: columns[index].format) - } - state = .bufferingRows(columns, buffer, readOnEmpty: false) - return .forwardRow(row, to: promise) + return .forwardRow(dataRow, to: promise) } case .initialized, diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PSQLChannelHandler.swift index b4606639..e0f71114 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PSQLChannelHandler.swift @@ -18,7 +18,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { self.logger.trace("Connection state changed", metadata: [.connectionState: "\(self.state)"]) } } - private var currentQuery: PSQLRows? + private var currentQuery: PSQLRowStream? private let authentificationConfiguration: PSQLConnection.Configuration.Authentication? private let configureSSLCallback: ((Channel) throws -> Void)? @@ -426,7 +426,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { self.run(action, with: context) return promise.futureResult } - let rows = PSQLRows( + let rows = PSQLRowStream( rowDescription: columns, queryContext: queryContext, eventLoop: context.channel.eventLoop, @@ -451,14 +451,14 @@ final class PSQLChannelHandler: ChannelDuplexHandler { context: ChannelHandlerContext) { let eventLoop = context.channel.eventLoop - let rows = PSQLRows( + let rows = PSQLRowStream( rowDescription: [], queryContext: queryContext, eventLoop: context.channel.eventLoop, cancel: { // ignore... }, next: { - let emptyBuffer = CircularBuffer<[PSQLData]>(initialCapacity: 0) + let emptyBuffer = CircularBuffer(initialCapacity: 0) return eventLoop.makeSucceededFuture(.complete(emptyBuffer, commandTag: commandTag)) }) queryContext.promise.succeed(rows) diff --git a/Sources/PostgresNIO/New/PSQLConnection.swift b/Sources/PostgresNIO/New/PSQLConnection.swift index 4689ed1a..1523692a 100644 --- a/Sources/PostgresNIO/New/PSQLConnection.swift +++ b/Sources/PostgresNIO/New/PSQLConnection.swift @@ -121,17 +121,17 @@ final class PSQLConnection { // MARK: Query - func query(_ query: String, logger: Logger) -> EventLoopFuture { + func query(_ query: String, logger: Logger) -> EventLoopFuture { self.query(query, [], logger: logger) } - func query(_ query: String, _ bind: [PSQLEncodable], logger: Logger) -> EventLoopFuture { + func query(_ query: String, _ bind: [PSQLEncodable], logger: Logger) -> EventLoopFuture { var logger = logger logger[postgresMetadataKey: .connectionID] = "\(self.connectionID)" guard bind.count <= Int(Int16.max) else { return self.channel.eventLoop.makeFailedFuture(PSQLError.tooManyParameters) } - let promise = self.channel.eventLoop.makePromise(of: PSQLRows.self) + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let context = ExtendedQueryContext( query: query, bind: bind, @@ -161,12 +161,12 @@ final class PSQLConnection { } func execute(_ preparedStatement: PSQLPreparedStatement, - _ bind: [PSQLEncodable], logger: Logger) -> EventLoopFuture + _ bind: [PSQLEncodable], logger: Logger) -> EventLoopFuture { guard bind.count <= Int(Int16.max) else { return self.channel.eventLoop.makeFailedFuture(PSQLError.tooManyParameters) } - let promise = self.channel.eventLoop.makePromise(of: PSQLRows.self) + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let context = ExtendedQueryContext( preparedStatement: preparedStatement, bind: bind, diff --git a/Sources/PostgresNIO/New/PSQLRow.swift b/Sources/PostgresNIO/New/PSQLRow.swift new file mode 100644 index 00000000..c5efb53a --- /dev/null +++ b/Sources/PostgresNIO/New/PSQLRow.swift @@ -0,0 +1,66 @@ + +/// `PSQLRow` represents a single row that was received from the Postgres Server. +struct PSQLRow { + internal let lookupTable: [String: Int] + internal let data: PSQLBackendMessage.DataRow + + internal let columns: [PSQLBackendMessage.RowDescription.Column] + internal let jsonDecoder: PSQLJSONDecoder + + internal init(data: PSQLBackendMessage.DataRow, lookupTable: [String: Int], columns: [PSQLBackendMessage.RowDescription.Column], jsonDecoder: PSQLJSONDecoder) { + self.data = data + self.lookupTable = lookupTable + self.columns = columns + self.jsonDecoder = jsonDecoder + } + + /// Access the raw Postgres data in the n-th column + subscript(index: Int) -> PSQLData { + PSQLData(bytes: self.data.columns[index], dataType: self.columns[index].dataType, format: self.columns[index].format) + } + + // TBD: Should this be optional? + /// Access the raw Postgres data in the column indentified by name + subscript(column columnName: String) -> PSQLData? { + guard let index = self.lookupTable[columnName] else { + return nil + } + + return self[index] + } + + /// Access the data in the provided column and decode it into the target type. + /// + /// - Parameters: + /// - column: The column name to read the data from + /// - type: The type to decode the data into + /// - Throws: The error of the decoding implementation. See also `PSQLDecodable` protocol for this. + /// - Returns: The decoded value of Type T. + func decode(column: String, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { + guard let index = self.lookupTable[column] else { + preconditionFailure("A column '\(column)' does not exist.") + } + + return try self.decode(column: index, as: type, file: file, line: line) + } + + /// Access the data in the provided column and decode it into the target type. + /// + /// - Parameters: + /// - column: The column index to read the data from + /// - type: The type to decode the data into + /// - Throws: The error of the decoding implementation. See also `PSQLDecodable` protocol for this. + /// - Returns: The decoded value of Type T. + func decode(column index: Int, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { + let column = self.columns[index] + + let decodingContext = PSQLDecodingContext( + jsonDecoder: jsonDecoder, + columnName: column.name, + columnIndex: index, + file: file, + line: line) + + return try self[index].decode(as: T.self, context: decodingContext) + } +} diff --git a/Sources/PostgresNIO/New/PSQLRows.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift similarity index 71% rename from Sources/PostgresNIO/New/PSQLRows.swift rename to Sources/PostgresNIO/New/PSQLRowStream.swift index a8632e7c..0f28a527 100644 --- a/Sources/PostgresNIO/New/PSQLRows.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -1,14 +1,14 @@ import NIOCore import Logging -final class PSQLRows { +final class PSQLRowStream { let eventLoop: EventLoop let logger: Logger private enum UpstreamState { case streaming(next: () -> EventLoopFuture, cancel: () -> ()) - case finished(remaining: CircularBuffer<[PSQLData]>, commandTag: String) + case finished(remaining: CircularBuffer, commandTag: String) case failure(Error) case consumed(Result) } @@ -46,7 +46,7 @@ final class PSQLRows { self.lookupTable = lookup } - func next() -> EventLoopFuture { + func next() -> EventLoopFuture { guard self.eventLoop.inEventLoop else { return self.eventLoop.flatSubmit { self.next() @@ -57,15 +57,15 @@ final class PSQLRows { switch self.upstreamState { case .streaming(let upstreamNext, _): - return upstreamNext().map { payload -> Row? in + return upstreamNext().map { payload -> PSQLRow? in self.downstreamState = .consuming switch payload { case .row(let data): - return Row(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) + return PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) case .complete(var buffer, let commandTag): if let data = buffer.popFirst() { self.upstreamState = .finished(remaining: buffer, commandTag: commandTag) - return Row(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) + return PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) } self.upstreamState = .consumed(.success(commandTag)) @@ -82,7 +82,7 @@ final class PSQLRows { self.downstreamState = .consuming if let data = buffer.popFirst() { self.upstreamState = .finished(remaining: buffer, commandTag: commandTag) - let row = Row(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) + let row = PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) return self.eventLoop.makeSucceededFuture(row) } @@ -104,7 +104,7 @@ final class PSQLRows { ]) } - internal func finalForward(_ finalForward: Result<(CircularBuffer<[PSQLData]>, commandTag: String), PSQLError>?) { + internal func finalForward(_ finalForward: Result<(CircularBuffer, commandTag: String), PSQLError>?) { switch finalForward { case .some(.success((let buffer, commandTag: let commandTag))): guard case .streaming = self.upstreamState else { @@ -146,56 +146,8 @@ final class PSQLRows { } return commandTag } - - struct Row { - let lookupTable: [String: Int] - let data: [PSQLData] - let columns: [PSQLBackendMessage.RowDescription.Column] - let jsonDecoder: PSQLJSONDecoder - - init(data: [PSQLData], lookupTable: [String: Int], columns: [PSQLBackendMessage.RowDescription.Column], jsonDecoder: PSQLJSONDecoder) { - self.data = data - self.lookupTable = lookupTable - self.columns = columns - self.jsonDecoder = jsonDecoder - } - - subscript(index: Int) -> PSQLData { - self.data[index] - } - - // TBD: Should this be optional? - subscript(column columnName: String) -> PSQLData? { - guard let index = self.lookupTable[columnName] else { - return nil - } - - return self[index] - } - func decode(column: String, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { - guard let index = self.lookupTable[column] else { - preconditionFailure("A column '\(column)' does not exist.") - } - - return try self.decode(column: index, as: type, file: file, line: line) - } - - func decode(column index: Int, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { - let column = self.columns[index] - - let decodingContext = PSQLDecodingContext( - jsonDecoder: jsonDecoder, - columnName: column.name, - columnIndex: index, - file: file, - line: line) - - return try self[index].decode(as: T.self, context: decodingContext) - } - } - - func onRow(_ onRow: @escaping (Row) -> EventLoopFuture) -> EventLoopFuture { + func onRow(_ onRow: @escaping (PSQLRow) -> EventLoopFuture) -> EventLoopFuture { let promise = self.eventLoop.makePromise(of: Void.self) func consumeNext(promise: EventLoopPromise) { diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 895201b8..af3e8ee4 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -29,13 +29,13 @@ final class ExtendedQueryContext { let logger: Logger let jsonDecoder: PSQLJSONDecoder - let promise: EventLoopPromise + let promise: EventLoopPromise init(query: String, bind: [PSQLEncodable], logger: Logger, jsonDecoder: PSQLJSONDecoder, - promise: EventLoopPromise) + promise: EventLoopPromise) { self.query = .unnamed(query) self.bind = bind @@ -48,7 +48,7 @@ final class ExtendedQueryContext { bind: [PSQLEncodable], logger: Logger, jsonDecoder: PSQLJSONDecoder, - promise: EventLoopPromise) + promise: EventLoopPromise) { self.query = .preparedStatement( name: preparedStatement.name, diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index d3b25ef7..c7112a5b 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -54,14 +54,14 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - var rows: PSQLRows? - XCTAssertNoThrow(rows = try conn?.query("SELECT version()", logger: .psqlTest).wait()) - var row: PSQLRows.Row? - XCTAssertNoThrow(row = try rows?.next().wait()) + var stream: PSQLRowStream? + XCTAssertNoThrow(stream = try conn?.query("SELECT version()", logger: .psqlTest).wait()) + var row: PSQLRow? + XCTAssertNoThrow(row = try stream?.next().wait()) var version: String? XCTAssertNoThrow(version = try row?.decode(column: 0, as: String.self)) XCTAssertEqual(version?.contains("PostgreSQL"), true) - XCTAssertNil(try rows?.next().wait()) + XCTAssertNil(try stream?.next().wait()) } func testQuery10kItems() { @@ -73,12 +73,12 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - var rows: PSQLRows? - XCTAssertNoThrow(rows = try conn?.query("SELECT generate_series(1, 10000);", logger: .psqlTest).wait()) + var stream: PSQLRowStream? + XCTAssertNoThrow(stream = try conn?.query("SELECT generate_series(1, 10000);", logger: .psqlTest).wait()) var expected: Int64 = 1 - XCTAssertNoThrow(try rows?.onRow { row in + XCTAssertNoThrow(try stream?.onRow { row in let promise = eventLoop.makePromise(of: Void.self) func workaround() { @@ -109,14 +109,14 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try conn?.close().wait()) } for _ in 0..<1_000 { - var rows: PSQLRows? - XCTAssertNoThrow(rows = try conn?.query("SELECT version()", logger: .psqlTest).wait()) - var row: PSQLRows.Row? - XCTAssertNoThrow(row = try rows?.next().wait()) + var stream: PSQLRowStream? + XCTAssertNoThrow(stream = try conn?.query("SELECT version()", logger: .psqlTest).wait()) + var row: PSQLRow? + XCTAssertNoThrow(row = try stream?.next().wait()) var version: String? XCTAssertNoThrow(version = try row?.decode(column: 0, as: String.self)) XCTAssertEqual(version?.contains("PostgreSQL"), true) - XCTAssertNil(try rows?.next().wait()) + XCTAssertNil(try stream?.next().wait()) } } @@ -129,14 +129,14 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - var rows: PSQLRows? - XCTAssertNoThrow(rows = try conn?.query("SELECT $1::TEXT as foo", ["hello"], logger: .psqlTest).wait()) - var row: PSQLRows.Row? - XCTAssertNoThrow(row = try rows?.next().wait()) + var stream: PSQLRowStream? + XCTAssertNoThrow(stream = try conn?.query("SELECT $1::TEXT as foo", ["hello"], logger: .psqlTest).wait()) + var row: PSQLRow? + XCTAssertNoThrow(row = try stream?.next().wait()) var foo: String? XCTAssertNoThrow(foo = try row?.decode(column: 0, as: String.self)) XCTAssertEqual(foo, "hello") - XCTAssertNil(try rows?.next().wait()) + XCTAssertNil(try stream?.next().wait()) } func testDecodeIntegers() { @@ -148,8 +148,8 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - var rows: PSQLRows? - XCTAssertNoThrow(rows = try conn?.query(""" + var stream: PSQLRowStream? + XCTAssertNoThrow(stream = try conn?.query(""" SELECT 1::SMALLINT as smallint, -32767::SMALLINT as smallint_min, @@ -162,8 +162,8 @@ final class IntegrationTests: XCTestCase { 9223372036854775807::BIGINT as bigint_max """, logger: .psqlTest).wait()) - var row: PSQLRows.Row? - XCTAssertNoThrow(row = try rows?.next().wait()) + var row: PSQLRow? + XCTAssertNoThrow(row = try stream?.next().wait()) XCTAssertEqual(try row?.decode(column: "smallint", as: Int16.self), 1) XCTAssertEqual(try row?.decode(column: "smallint_min", as: Int16.self), -32_767) @@ -175,7 +175,7 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(try row?.decode(column: "bigint_min", as: Int64.self), -9_223_372_036_854_775_807) XCTAssertEqual(try row?.decode(column: "bigint_max", as: Int64.self), 9_223_372_036_854_775_807) - XCTAssertNil(try rows?.next().wait()) + XCTAssertNil(try stream?.next().wait()) } func testEncodeAndDecodeIntArray() { @@ -187,15 +187,15 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - var rows: PSQLRows? + var stream: PSQLRowStream? let array: [Int64] = [1, 2, 3] - XCTAssertNoThrow(rows = try conn?.query("SELECT $1::int8[] as array", [array], logger: .psqlTest).wait()) + XCTAssertNoThrow(stream = try conn?.query("SELECT $1::int8[] as array", [array], logger: .psqlTest).wait()) - var row: PSQLRows.Row? - XCTAssertNoThrow(row = try rows?.next().wait()) + var row: PSQLRow? + XCTAssertNoThrow(row = try stream?.next().wait()) XCTAssertEqual(try row?.decode(column: "array", as: [Int64].self), array) - XCTAssertNil(try rows?.next().wait()) + XCTAssertNil(try stream?.next().wait()) } func testDecodeEmptyIntegerArray() { @@ -207,14 +207,14 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - var rows: PSQLRows? - XCTAssertNoThrow(rows = try conn?.query("SELECT '{}'::int[] as array", logger: .psqlTest).wait()) + var stream: PSQLRowStream? + XCTAssertNoThrow(stream = try conn?.query("SELECT '{}'::int[] as array", logger: .psqlTest).wait()) - var row: PSQLRows.Row? - XCTAssertNoThrow(row = try rows?.next().wait()) + var row: PSQLRow? + XCTAssertNoThrow(row = try stream?.next().wait()) XCTAssertEqual(try row?.decode(column: "array", as: [Int64].self), []) - XCTAssertNil(try rows?.next().wait()) + XCTAssertNil(try stream?.next().wait()) } func testDoubleArraySerialization() { @@ -226,15 +226,15 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - var rows: PSQLRows? + var stream: PSQLRowStream? let doubles: [Double] = [3.14, 42] - XCTAssertNoThrow(rows = try conn?.query("SELECT $1::double precision[] as doubles", [doubles], logger: .psqlTest).wait()) + XCTAssertNoThrow(stream = try conn?.query("SELECT $1::double precision[] as doubles", [doubles], logger: .psqlTest).wait()) - var row: PSQLRows.Row? - XCTAssertNoThrow(row = try rows?.next().wait()) + var row: PSQLRow? + XCTAssertNoThrow(row = try stream?.next().wait()) XCTAssertEqual(try row?.decode(column: "doubles", as: [Double].self), doubles) - XCTAssertNil(try rows?.next().wait()) + XCTAssertNil(try stream?.next().wait()) } func testDecodeDates() { @@ -246,22 +246,22 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - var rows: PSQLRows? - XCTAssertNoThrow(rows = try conn?.query(""" + var stream: PSQLRowStream? + XCTAssertNoThrow(stream = try conn?.query(""" SELECT '2016-01-18 01:02:03 +0042'::DATE as date, '2016-01-18 01:02:03 +0042'::TIMESTAMP as timestamp, '2016-01-18 01:02:03 +0042'::TIMESTAMPTZ as timestamptz """, logger: .psqlTest).wait()) - var row: PSQLRows.Row? - XCTAssertNoThrow(row = try rows?.next().wait()) + var row: PSQLRow? + XCTAssertNoThrow(row = try stream?.next().wait()) XCTAssertEqual(try row?.decode(column: "date", as: Date.self).description, "2016-01-18 00:00:00 +0000") XCTAssertEqual(try row?.decode(column: "timestamp", as: Date.self).description, "2016-01-18 01:02:03 +0000") XCTAssertEqual(try row?.decode(column: "timestamptz", as: Date.self).description, "2016-01-18 00:20:03 +0000") - XCTAssertNil(try rows?.next().wait()) + XCTAssertNil(try stream?.next().wait()) } func testDecodeUUID() { @@ -273,17 +273,17 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - var rows: PSQLRows? - XCTAssertNoThrow(rows = try conn?.query(""" + var stream: PSQLRowStream? + XCTAssertNoThrow(stream = try conn?.query(""" SELECT '2c68f645-9ca6-468b-b193-ee97f241c2f8'::UUID as uuid """, logger: .psqlTest).wait()) - var row: PSQLRows.Row? - XCTAssertNoThrow(row = try rows?.next().wait()) + var row: PSQLRow? + XCTAssertNoThrow(row = try stream?.next().wait()) XCTAssertEqual(try row?.decode(column: "uuid", as: UUID.self), UUID(uuidString: "2c68f645-9ca6-468b-b193-ee97f241c2f8")) - XCTAssertNil(try rows?.next().wait()) + XCTAssertNil(try stream?.next().wait()) } func testRoundTripJSONB() { @@ -301,35 +301,35 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try conn?.close().wait()) } do { - var rows: PSQLRows? - XCTAssertNoThrow(rows = try conn?.query(""" + var stream: PSQLRowStream? + XCTAssertNoThrow(stream = try conn?.query(""" select $1::jsonb as jsonb """, [Object(foo: 1, bar: 2)], logger: .psqlTest).wait()) - var row: PSQLRows.Row? - XCTAssertNoThrow(row = try rows?.next().wait()) + var row: PSQLRow? + XCTAssertNoThrow(row = try stream?.next().wait()) var result: Object? XCTAssertNoThrow(result = try row?.decode(column: "jsonb", as: Object.self)) XCTAssertEqual(result?.foo, 1) XCTAssertEqual(result?.bar, 2) - XCTAssertNil(try rows?.next().wait()) + XCTAssertNil(try stream?.next().wait()) } do { - var rows: PSQLRows? - XCTAssertNoThrow(rows = try conn?.query(""" + var stream: PSQLRowStream? + XCTAssertNoThrow(stream = try conn?.query(""" select $1::json as json """, [Object(foo: 1, bar: 2)], logger: .psqlTest).wait()) - var row: PSQLRows.Row? - XCTAssertNoThrow(row = try rows?.next().wait()) + var row: PSQLRow? + XCTAssertNoThrow(row = try stream?.next().wait()) var result: Object? XCTAssertNoThrow(result = try row?.decode(column: "json", as: Object.self)) XCTAssertEqual(result?.foo, 1) XCTAssertEqual(result?.bar, 2) - XCTAssertNil(try rows?.next().wait()) + XCTAssertNil(try stream?.next().wait()) } } } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index eb282444..e796c0f9 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -125,7 +125,7 @@ class ConnectionStateMachineTests: XCTestCase { let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) let jsonDecoder = JSONDecoder() - let queryPromise = eventLoopGroup.next().makePromise(of: PSQLRows.self) + let queryPromise = eventLoopGroup.next().makePromise(of: PSQLRowStream.self) var state = ConnectionStateMachine() let extendedQueryContext = ExtendedQueryContext( diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index 6cb48324..ea457bd5 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -10,7 +10,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { var state = ConnectionStateMachine.readyForQuery() let logger = Logger.psqlTest - let promise = EmbeddedEventLoop().makePromise(of: PSQLRows.self) + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. let query = "DELETE FROM table WHERE id=$0" let queryContext = ExtendedQueryContext(query: query, bind: [1], logger: logger, jsonDecoder: JSONDecoder(), promise: promise) @@ -28,7 +28,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { var state = ConnectionStateMachine.readyForQuery() let logger = Logger.psqlTest - let queryPromise = EmbeddedEventLoop().makePromise(of: PSQLRows.self) + let queryPromise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) queryPromise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. let query = "SELECT version()" let queryContext = ExtendedQueryContext(query: query, bind: [], logger: logger, jsonDecoder: JSONDecoder(), promise: queryPromise) @@ -56,7 +56,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let rowPromise = EmbeddedEventLoop().makePromise(of: StateMachineStreamNextResult.self) rowPromise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. - XCTAssertEqual(state.consumeNextQueryRow(promise: rowPromise), .forwardRow([.init(bytes: rowContent, dataType: .text, format: .binary)], to: rowPromise)) + XCTAssertEqual(state.consumeNextQueryRow(promise: rowPromise), .forwardRow(.init(columns: [rowContent]), to: rowPromise)) XCTAssertEqual(state.commandCompletedReceived("SELECT 1"), .forwardStreamCompletedToCurrentQuery(CircularBuffer(), commandTag: "SELECT 1", read: true)) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) @@ -66,7 +66,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { var state = ConnectionStateMachine.readyForQuery() let logger = Logger.psqlTest - let promise = EmbeddedEventLoop().makePromise(of: PSQLRows.self) + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. let query = "DELETE FROM table WHERE id=$0" let queryContext = ExtendedQueryContext(query: query, bind: [1], logger: logger, jsonDecoder: JSONDecoder(), promise: promise) From 8c32013354899e1e3b527b08577143061b55681c Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sun, 19 Sep 2021 13:16:18 +0200 Subject: [PATCH 020/292] Add BufferedMessageEncoder, fix PSQLFrontendMessageEncoder (#179) ### Motivation - We want to communicate as efficiently as possible to the server. For this reason we use message pipelining within a query. - If we drop the current use of `MessageToByteHandler` and do the encoding in the `PSQLChannelHandler`, we can write to a single ByteBuffer and pass it once to the pipeline with a flush ### Changes - Add `BufferedMessageEncoder` - Rename `PSQLFrontendMessage.Encoder` to `PSQLFrontendMessageEncoder` (got its own file) --- .../New/BufferedMessageEncoder.swift | 39 ++++++++ Sources/PostgresNIO/New/PSQLConnection.swift | 2 +- .../PostgresNIO/New/PSQLFrontendMessage.swift | 76 ---------------- .../New/PSQLFrontendMessageEncoder.swift | 88 +++++++++++++++++++ .../New/Extensions/PSQLCoding+TestUtils.swift | 2 +- .../New/Messages/BindTests.swift | 2 +- .../New/Messages/CancelTests.swift | 2 +- .../New/Messages/CloseTests.swift | 4 +- .../New/Messages/DescribeTests.swift | 4 +- .../New/Messages/ExecuteTests.swift | 2 +- .../New/Messages/ParseTests.swift | 2 +- .../New/Messages/PasswordTests.swift | 2 +- .../Messages/SASLInitialResponseTests.swift | 4 +- .../New/Messages/SASLResponseTests.swift | 4 +- .../New/Messages/SSLRequestTests.swift | 2 +- .../New/Messages/StartupTests.swift | 2 +- .../New/PSQLFrontendMessageTests.swift | 6 +- 17 files changed, 147 insertions(+), 96 deletions(-) create mode 100644 Sources/PostgresNIO/New/BufferedMessageEncoder.swift create mode 100644 Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift diff --git a/Sources/PostgresNIO/New/BufferedMessageEncoder.swift b/Sources/PostgresNIO/New/BufferedMessageEncoder.swift new file mode 100644 index 00000000..9c02871e --- /dev/null +++ b/Sources/PostgresNIO/New/BufferedMessageEncoder.swift @@ -0,0 +1,39 @@ +import NIOCore + +struct BufferedMessageEncoder { + private enum State { + case flushed + case writable + } + + private var buffer: ByteBuffer + private var state: State = .writable + private var encoder: Encoder + + init(buffer: ByteBuffer, encoder: Encoder) { + self.buffer = buffer + self.encoder = encoder + } + + mutating func encode(_ message: Encoder.OutboundIn) throws { + switch self.state { + case .flushed: + self.state = .writable + self.buffer.clear() + + case .writable: + break + } + + try self.encoder.encode(data: message, out: &self.buffer) + } + + mutating func flush() -> ByteBuffer? { + guard self.buffer.readableBytes > 0 else { + return nil + } + + self.state = .flushed + return self.buffer + } +} diff --git a/Sources/PostgresNIO/New/PSQLConnection.swift b/Sources/PostgresNIO/New/PSQLConnection.swift index 1523692a..54c58aee 100644 --- a/Sources/PostgresNIO/New/PSQLConnection.swift +++ b/Sources/PostgresNIO/New/PSQLConnection.swift @@ -231,7 +231,7 @@ final class PSQLConnection { return channel.pipeline.addHandlers([ decoder, - MessageToByteHandler(PSQLFrontendMessage.Encoder(jsonEncoder: configuration.coders.jsonEncoder)), + MessageToByteHandler(PSQLFrontendMessageEncoder(jsonEncoder: configuration.coders.jsonEncoder)), PSQLChannelHandler( authentification: configuration.authentication, logger: logger, diff --git a/Sources/PostgresNIO/New/PSQLFrontendMessage.swift b/Sources/PostgresNIO/New/PSQLFrontendMessage.swift index 5800b2da..56e94ff0 100644 --- a/Sources/PostgresNIO/New/PSQLFrontendMessage.swift +++ b/Sources/PostgresNIO/New/PSQLFrontendMessage.swift @@ -129,82 +129,6 @@ extension PSQLFrontendMessage { } } -extension PSQLFrontendMessage { - struct Encoder: MessageToByteEncoder { - typealias OutboundIn = PSQLFrontendMessage - - let jsonEncoder: PSQLJSONEncoder - - init(jsonEncoder: PSQLJSONEncoder) { - self.jsonEncoder = jsonEncoder - } - - func encode(data message: PSQLFrontendMessage, out buffer: inout ByteBuffer) throws { - struct EmptyPayload: PSQLMessagePayloadEncodable { - func encode(into buffer: inout ByteBuffer) {} - } - - func encode(_ payload: Payload, into buffer: inout ByteBuffer) { - let startIndex = buffer.writerIndex - buffer.writeInteger(Int32(0)) // placeholder for length - payload.encode(into: &buffer) - let length = Int32(buffer.writerIndex - startIndex) - buffer.setInteger(length, at: startIndex) - } - - switch message { - case .bind(let bind): - buffer.writeInteger(message.id.rawValue) - let startIndex = buffer.writerIndex - buffer.writeInteger(Int32(0)) // placeholder for length - try bind.encode(into: &buffer, using: self.jsonEncoder) - let length = Int32(buffer.writerIndex - startIndex) - buffer.setInteger(length, at: startIndex) - - case .cancel(let cancel): - // cancel requests don't have an identifier - encode(cancel, into: &buffer) - case .close(let close): - buffer.writeFrontendMessageID(message.id) - encode(close, into: &buffer) - case .describe(let describe): - buffer.writeFrontendMessageID(message.id) - encode(describe, into: &buffer) - case .execute(let execute): - buffer.writeFrontendMessageID(message.id) - encode(execute, into: &buffer) - case .flush: - buffer.writeFrontendMessageID(message.id) - encode(EmptyPayload(), into: &buffer) - case .parse(let parse): - buffer.writeFrontendMessageID(message.id) - encode(parse, into: &buffer) - case .password(let password): - buffer.writeFrontendMessageID(message.id) - encode(password, into: &buffer) - case .saslInitialResponse(let saslInitialResponse): - buffer.writeFrontendMessageID(message.id) - encode(saslInitialResponse, into: &buffer) - case .saslResponse(let saslResponse): - buffer.writeFrontendMessageID(message.id) - encode(saslResponse, into: &buffer) - case .sslRequest(let request): - // sslRequests don't have an identifier - encode(request, into: &buffer) - case .startup(let startup): - // startup requests don't have an identifier - encode(startup, into: &buffer) - case .sync: - buffer.writeFrontendMessageID(message.id) - encode(EmptyPayload(), into: &buffer) - case .terminate: - buffer.writeFrontendMessageID(message.id) - encode(EmptyPayload(), into: &buffer) - } - } - } -} - protocol PSQLMessagePayloadEncodable { func encode(into buffer: inout ByteBuffer) } diff --git a/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift new file mode 100644 index 00000000..0a998285 --- /dev/null +++ b/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift @@ -0,0 +1,88 @@ + +struct PSQLFrontendMessageEncoder: MessageToByteEncoder { + typealias OutboundIn = PSQLFrontendMessage + + let jsonEncoder: PSQLJSONEncoder + + init(jsonEncoder: PSQLJSONEncoder) { + self.jsonEncoder = jsonEncoder + } + + func encode(data message: PSQLFrontendMessage, out buffer: inout ByteBuffer) throws { + switch message { + case .bind(let bind): + buffer.writeInteger(message.id.rawValue) + let startIndex = buffer.writerIndex + buffer.writeInteger(Int32(0)) // placeholder for length + try bind.encode(into: &buffer, using: self.jsonEncoder) + let length = Int32(buffer.writerIndex - startIndex) + buffer.setInteger(length, at: startIndex) + + case .cancel(let cancel): + // cancel requests don't have an identifier + self.encode(payload: cancel, into: &buffer) + + case .close(let close): + self.encode(messageID: message.id, payload: close, into: &buffer) + + case .describe(let describe): + self.encode(messageID: message.id, payload: describe, into: &buffer) + + case .execute(let execute): + self.encode(messageID: message.id, payload: execute, into: &buffer) + + case .flush: + self.encode(messageID: message.id, payload: EmptyPayload(), into: &buffer) + + case .parse(let parse): + self.encode(messageID: message.id, payload: parse, into: &buffer) + + case .password(let password): + self.encode(messageID: message.id, payload: password, into: &buffer) + + case .saslInitialResponse(let saslInitialResponse): + self.encode(messageID: message.id, payload: saslInitialResponse, into: &buffer) + + case .saslResponse(let saslResponse): + self.encode(messageID: message.id, payload: saslResponse, into: &buffer) + + case .sslRequest(let request): + // sslRequests don't have an identifier + self.encode(payload: request, into: &buffer) + + case .startup(let startup): + // startup requests don't have an identifier + self.encode(payload: startup, into: &buffer) + + case .sync: + self.encode(messageID: message.id, payload: EmptyPayload(), into: &buffer) + + case .terminate: + self.encode(messageID: message.id, payload: EmptyPayload(), into: &buffer) + } + } + + private struct EmptyPayload: PSQLMessagePayloadEncodable { + func encode(into buffer: inout ByteBuffer) {} + } + + private func encode( + messageID: PSQLFrontendMessage.ID, + payload: Payload, + into buffer: inout ByteBuffer) + { + buffer.writeFrontendMessageID(messageID) + self.encode(payload: payload, into: &buffer) + } + + private func encode( + payload: Payload, + into buffer: inout ByteBuffer) + { + let startIndex = buffer.writerIndex + buffer.writeInteger(Int32(0)) // placeholder for length + payload.encode(into: &buffer) + let length = Int32(buffer.writerIndex - startIndex) + buffer.setInteger(length, at: startIndex) + } +} diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift index 569a9ea6..b6f2e1d1 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift @@ -1,7 +1,7 @@ @testable import PostgresNIO import Foundation -extension PSQLFrontendMessage.Encoder { +extension PSQLFrontendMessageEncoder { static var forTests: Self { Self(jsonEncoder: JSONEncoder()) } diff --git a/Tests/PostgresNIOTests/New/Messages/BindTests.swift b/Tests/PostgresNIOTests/New/Messages/BindTests.swift index 1f300342..7a688d41 100644 --- a/Tests/PostgresNIOTests/New/Messages/BindTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BindTests.swift @@ -5,7 +5,7 @@ import NIOCore class BindTests: XCTestCase { func testEncodeBind() { - let encoder = PSQLFrontendMessage.Encoder.forTests + let encoder = PSQLFrontendMessageEncoder.forTests var byteBuffer = ByteBuffer() let bind = PSQLFrontendMessage.Bind(portalName: "", preparedStatementName: "", parameters: ["Hello", "World"]) let message = PSQLFrontendMessage.bind(bind) diff --git a/Tests/PostgresNIOTests/New/Messages/CancelTests.swift b/Tests/PostgresNIOTests/New/Messages/CancelTests.swift index 80ac98d0..551e5769 100644 --- a/Tests/PostgresNIOTests/New/Messages/CancelTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CancelTests.swift @@ -5,7 +5,7 @@ import NIOCore class CancelTests: XCTestCase { func testEncodeCancel() { - let encoder = PSQLFrontendMessage.Encoder.forTests + let encoder = PSQLFrontendMessageEncoder.forTests var byteBuffer = ByteBuffer() let cancel = PSQLFrontendMessage.Cancel(processID: 1234, secretKey: 4567) let message = PSQLFrontendMessage.cancel(cancel) diff --git a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift index b7734ebf..4df15896 100644 --- a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift @@ -5,7 +5,7 @@ import NIOCore class CloseTests: XCTestCase { func testEncodeClosePortal() { - let encoder = PSQLFrontendMessage.Encoder.forTests + let encoder = PSQLFrontendMessageEncoder.forTests var byteBuffer = ByteBuffer() let message = PSQLFrontendMessage.close(.portal("Hello")) XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) @@ -19,7 +19,7 @@ class CloseTests: XCTestCase { } func testEncodeCloseUnnamedStatement() { - let encoder = PSQLFrontendMessage.Encoder.forTests + let encoder = PSQLFrontendMessageEncoder.forTests var byteBuffer = ByteBuffer() let message = PSQLFrontendMessage.close(.preparedStatement("")) XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) diff --git a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift index 5ce6d2f0..87f7d09b 100644 --- a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift @@ -5,7 +5,7 @@ import NIOCore class DescribeTests: XCTestCase { func testEncodeDescribePortal() { - let encoder = PSQLFrontendMessage.Encoder.forTests + let encoder = PSQLFrontendMessageEncoder.forTests var byteBuffer = ByteBuffer() let message = PSQLFrontendMessage.describe(.portal("Hello")) XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) @@ -19,7 +19,7 @@ class DescribeTests: XCTestCase { } func testEncodeDescribeUnnamedStatement() { - let encoder = PSQLFrontendMessage.Encoder.forTests + let encoder = PSQLFrontendMessageEncoder.forTests var byteBuffer = ByteBuffer() let message = PSQLFrontendMessage.describe(.preparedStatement("")) XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) diff --git a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift index 01093060..3ce8d63d 100644 --- a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift @@ -5,7 +5,7 @@ import NIOCore class ExecuteTests: XCTestCase { func testEncodeExecute() { - let encoder = PSQLFrontendMessage.Encoder.forTests + let encoder = PSQLFrontendMessageEncoder.forTests var byteBuffer = ByteBuffer() let message = PSQLFrontendMessage.execute(.init(portalName: "", maxNumberOfRows: 0)) XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) diff --git a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift index 1239633d..c147b749 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift @@ -5,7 +5,7 @@ import NIOCore class ParseTests: XCTestCase { func testEncode() { - let encoder = PSQLFrontendMessage.Encoder.forTests + let encoder = PSQLFrontendMessageEncoder.forTests var byteBuffer = ByteBuffer() let parse = PSQLFrontendMessage.Parse( preparedStatementName: "test", diff --git a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift index 0e8e2920..73c464f3 100644 --- a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift @@ -5,7 +5,7 @@ import NIOCore class PasswordTests: XCTestCase { func testEncodePassword() { - let encoder = PSQLFrontendMessage.Encoder.forTests + let encoder = PSQLFrontendMessageEncoder.forTests var byteBuffer = ByteBuffer() // md522d085ed8dc3377968dc1c1a40519a2a = "abc123" with salt 1, 2, 3, 4 let message = PSQLFrontendMessage.password(.init(value: "md522d085ed8dc3377968dc1c1a40519a2a")) diff --git a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift index 00a601e4..af2459ac 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift @@ -5,7 +5,7 @@ import NIOCore class SASLInitialResponseTests: XCTestCase { func testEncodeWithData() { - let encoder = PSQLFrontendMessage.Encoder.forTests + let encoder = PSQLFrontendMessageEncoder.forTests var byteBuffer = ByteBuffer() let sasl = PSQLFrontendMessage.SASLInitialResponse( saslMechanism: "hello", initialData: [0, 1, 2, 3, 4, 5, 6, 7]) @@ -30,7 +30,7 @@ class SASLInitialResponseTests: XCTestCase { } func testEncodeWithoutData() { - let encoder = PSQLFrontendMessage.Encoder.forTests + let encoder = PSQLFrontendMessageEncoder.forTests var byteBuffer = ByteBuffer() let sasl = PSQLFrontendMessage.SASLInitialResponse( saslMechanism: "hello", initialData: []) diff --git a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift index 6f117105..aeb4448a 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift @@ -5,7 +5,7 @@ import NIOCore class SASLResponseTests: XCTestCase { func testEncodeWithData() { - let encoder = PSQLFrontendMessage.Encoder.forTests + let encoder = PSQLFrontendMessageEncoder.forTests var byteBuffer = ByteBuffer() let sasl = PSQLFrontendMessage.SASLResponse(data: [0, 1, 2, 3, 4, 5, 6, 7]) let message = PSQLFrontendMessage.saslResponse(sasl) @@ -21,7 +21,7 @@ class SASLResponseTests: XCTestCase { } func testEncodeWithoutData() { - let encoder = PSQLFrontendMessage.Encoder.forTests + let encoder = PSQLFrontendMessageEncoder.forTests var byteBuffer = ByteBuffer() let sasl = PSQLFrontendMessage.SASLResponse(data: []) let message = PSQLFrontendMessage.saslResponse(sasl) diff --git a/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift b/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift index 7f8e57f4..bf7cac41 100644 --- a/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift @@ -5,7 +5,7 @@ import NIOCore class SSLRequestTests: XCTestCase { func testSSLRequest() { - let encoder = PSQLFrontendMessage.Encoder.forTests + let encoder = PSQLFrontendMessageEncoder.forTests var byteBuffer = ByteBuffer() let request = PSQLFrontendMessage.SSLRequest() let message = PSQLFrontendMessage.sslRequest(request) diff --git a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift index 3a386bd3..1224aede 100644 --- a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift @@ -5,7 +5,7 @@ import NIOCore class StartupTests: XCTestCase { func testStartupMessage() { - let encoder = PSQLFrontendMessage.Encoder.forTests + let encoder = PSQLFrontendMessageEncoder.forTests var byteBuffer = ByteBuffer() let replicationValues: [PSQLFrontendMessage.Startup.Parameters.Replication] = [ diff --git a/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift index 55c4c4d1..83b41392 100644 --- a/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift @@ -23,7 +23,7 @@ class PSQLFrontendMessageTests: XCTestCase { // MARK: Encoder func testEncodeFlush() { - let encoder = PSQLFrontendMessage.Encoder.forTests + let encoder = PSQLFrontendMessageEncoder.forTests var byteBuffer = ByteBuffer() XCTAssertNoThrow(try encoder.encode(data: .flush, out: &byteBuffer)) @@ -33,7 +33,7 @@ class PSQLFrontendMessageTests: XCTestCase { } func testEncodeSync() { - let encoder = PSQLFrontendMessage.Encoder.forTests + let encoder = PSQLFrontendMessageEncoder.forTests var byteBuffer = ByteBuffer() XCTAssertNoThrow(try encoder.encode(data: .sync, out: &byteBuffer)) @@ -43,7 +43,7 @@ class PSQLFrontendMessageTests: XCTestCase { } func testEncodeTerminate() { - let encoder = PSQLFrontendMessage.Encoder.forTests + let encoder = PSQLFrontendMessageEncoder.forTests var byteBuffer = ByteBuffer() XCTAssertNoThrow(try encoder.encode(data: .terminate, out: &byteBuffer)) From 419c20b40fcfb7416de1ba8238b4a86ccaafc84c Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sun, 19 Sep 2021 21:20:04 +0200 Subject: [PATCH 021/292] Add PSQLFrontendMessageDecoder (#178) ### Motivation In tests we would like to use `EmbeddedChannel`, if we decode and encode in the `PSQLChannelHandler` using the `NIOSingleStepByteToMessageDecoder`, we need to decode messages the handler wrote to the channel from bytes. ### Changes - Add `PSQLFrontendMessageDecoder` - Add `ReverseByteToMessageHandler` --- .../New/PSQLBackendMessageDecoder.swift | 6 +- .../PSQLFrontendMessageDecoder.swift | 172 ++++++++++++++++++ .../Extensions/ReverseChannelDecoder.swift | 36 ++++ 3 files changed, 211 insertions(+), 3 deletions(-) create mode 100644 Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift create mode 100644 Tests/PostgresNIOTests/New/Extensions/ReverseChannelDecoder.swift diff --git a/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift b/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift index 58f5c460..edf386df 100644 --- a/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift +++ b/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift @@ -69,7 +69,7 @@ struct PSQLBackendMessageDecoder: NIOSingleStepByteToMessageDecoder { return try PSQLBackendMessage.decode(from: &slice, for: messageID) } catch let error as PSQLPartialDecodingError { - throw PSQLDecodingError.withPartialError(error, messageID: messageID, messageBytes: completeMessageBuffer) + throw PSQLDecodingError.withPartialError(error, messageID: messageID.rawValue, messageBytes: completeMessageBuffer) } catch { preconditionFailure("Expected to only see `PartialDecodingError`s here.") } @@ -106,14 +106,14 @@ struct PSQLDecodingError: Error { static func withPartialError( _ partialError: PSQLPartialDecodingError, - messageID: PSQLBackendMessage.ID, + messageID: UInt8, messageBytes: ByteBuffer) -> Self { var byteBuffer = messageBytes let data = byteBuffer.readData(length: byteBuffer.readableBytes)! return PSQLDecodingError( - messageID: messageID.rawValue, + messageID: messageID, payload: data.base64EncodedString(), description: partialError.description, file: partialError.file, diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift new file mode 100644 index 00000000..79e56507 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -0,0 +1,172 @@ +@testable import PostgresNIO + +struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { + typealias InboundOut = PSQLFrontendMessage + + private(set) var isInStartup: Bool + + init() { + self.isInStartup = true + } + + mutating func decode(buffer: inout ByteBuffer) throws -> PSQLFrontendMessage? { + // make sure we have at least one byte to read + guard buffer.readableBytes > 0 else { + return nil + } + + if self.isInStartup { + guard let length = buffer.getInteger(at: buffer.readerIndex, as: UInt32.self) else { + return nil + } + + guard var messageSlice = buffer.getSlice(at: buffer.readerIndex &+ 4, length: Int(length)) else { + return nil + } + buffer.moveReaderIndex(forwardBy: 4 &+ Int(length)) + let finalIndex = buffer.readerIndex + + guard let code = buffer.readInteger(as: UInt32.self) else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: UInt32.self) + } + + switch code { + case 80877103: + self.isInStartup = true + return .sslRequest(.init()) + + case 196608: + var user: String? + var database: String? + var options: String? + + while let name = messageSlice.readNullTerminatedString(), messageSlice.readerIndex < finalIndex { + let value = messageSlice.readNullTerminatedString() + + switch name { + case "user": + user = value + + case "database": + database = value + + case "options": + options = value + + default: + break + } + } + + let parameters = PSQLFrontendMessage.Startup.Parameters( + user: user!, + database: database, + options: options, + replication: .false + ) + + let startup = PSQLFrontendMessage.Startup( + protocolVersion: 0x00_03_00_00, + parameters: parameters + ) + + precondition(buffer.readerIndex == finalIndex) + self.isInStartup = false + + return .startup(startup) + + default: + throw PSQLDecodingError.unknownStartupCodeReceived(code: code, messageBytes: messageSlice) + } + } + + // all other packages have an Int32 after the identifier that determines their length. + // do we have enough bytes for that? + guard let idByte = buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), + let length = buffer.getInteger(at: buffer.readerIndex + 1, as: Int32.self) else { + return nil + } + + // At this point we are sure, that we have enough bytes to decode the next message. + // 1. Create a byteBuffer that represents exactly the next message. This can be force + // unwrapped, since it was verified that enough bytes are available. + guard let completeMessageBuffer = buffer.readSlice(length: 1 + Int(length)) else { + return nil + } + + // 2. make sure we have a known message identifier + guard let messageID = PSQLFrontendMessage.ID(rawValue: idByte) else { + throw PSQLDecodingError.unknownMessageIDReceived(messageID: idByte, messageBytes: completeMessageBuffer) + } + + // 3. decode the message + do { + // get a mutable byteBuffer copy + var slice = completeMessageBuffer + // move reader index forward by five bytes + slice.moveReaderIndex(forwardBy: 5) + + return try PSQLFrontendMessage.decode(from: &slice, for: messageID) + } catch let error as PSQLPartialDecodingError { + throw PSQLDecodingError.withPartialError(error, messageID: messageID.rawValue, messageBytes: completeMessageBuffer) + } catch { + preconditionFailure("Expected to only see `PartialDecodingError`s here.") + } + } + + mutating func decodeLast(buffer: inout ByteBuffer, seenEOF: Bool) throws -> PSQLFrontendMessage? { + try self.decode(buffer: &buffer) + } +} + +extension PSQLFrontendMessage { + + static func decode(from buffer: inout ByteBuffer, for messageID: ID) throws -> PSQLFrontendMessage { + switch messageID { + case .bind: + preconditionFailure("TODO: Unimplemented") + case .close: + preconditionFailure("TODO: Unimplemented") + case .describe: + preconditionFailure("TODO: Unimplemented") + case .execute: + preconditionFailure("TODO: Unimplemented") + case .flush: + return .flush + case .parse: + preconditionFailure("TODO: Unimplemented") + case .password: + guard let password = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + return .password(.init(value: password)) + case .saslInitialResponse: + preconditionFailure("TODO: Unimplemented") + case .saslResponse: + preconditionFailure("TODO: Unimplemented") + case .sync: + return .sync + case .terminate: + return .terminate + } + } +} + +extension PSQLDecodingError { + static func unknownStartupCodeReceived( + code: UInt32, + messageBytes: ByteBuffer, + file: String = #file, + line: Int = #line) -> Self + { + var byteBuffer = messageBytes + let data = byteBuffer.readData(length: byteBuffer.readableBytes)! + + return PSQLDecodingError( + messageID: 0, + payload: data.base64EncodedString(), + description: "Received a startup code '\(code)'. There is no message associated with this code.", + file: file, + line: line) + } +} diff --git a/Tests/PostgresNIOTests/New/Extensions/ReverseChannelDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/ReverseChannelDecoder.swift new file mode 100644 index 00000000..654a2546 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Extensions/ReverseChannelDecoder.swift @@ -0,0 +1,36 @@ +import NIOCore + +/// This is a reverse ``NIOCore/ByteToMessageHandler``. Instead of creating messages from incoming bytes +/// as the normal `ByteToMessageHandler` does, this `ReverseByteToMessageHandler` creates messages +/// from outgoing bytes. This is only important for testing in `EmbeddedChannel`s. +class ReverseByteToMessageHandler: ChannelOutboundHandler { + typealias OutboundIn = ByteBuffer + typealias OutboundOut = Decoder.InboundOut + + let processor: NIOSingleStepByteToMessageProcessor + + init(_ decoder: Decoder) { + self.processor = .init(decoder, maximumBufferSize: nil) + } + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let buffer = self.unwrapOutboundIn(data) + + do { + var messages = [Decoder.InboundOut]() + try self.processor.process(buffer: buffer) { message in + messages.append(message) + } + + for (index, message) in messages.enumerated() { + if index == messages.index(before: messages.endIndex) { + context.write(self.wrapOutboundOut(message), promise: promise) + } else { + context.write(self.wrapOutboundOut(message), promise: nil) + } + } + } catch { + context.fireErrorCaught(error) + } + } +} From 1017bca28627ccf90e99c4585c59ba64a013b1b8 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 23 Sep 2021 09:56:33 +0200 Subject: [PATCH 022/292] Batch rows for consumption (#180) ### Motivation To allow faster processing of incoming `DataRow`s, we should batch them up in `channelRead` events and forward them as a batch for consumption in `PSQLRowStream`. This work is the foundation for AsyncSequence support in the future. ### Modifications - Extends `ExtendedQueryStateMachine` to use `RowStreamStateMachine` internally - Refactor `PSQLRowStream` to work with batches of rows. --- .../PostgresConnection+Database.swift | 46 ++- .../ConnectionStateMachine.swift | 173 +++++++--- .../ExtendedQueryStateMachine.swift | 143 ++++---- .../PostgresNIO/New/PSQLChannelHandler.swift | 122 +++---- Sources/PostgresNIO/New/PSQLRowStream.swift | 307 +++++++++++++----- .../PSQLIntegrationTests.swift | 108 +++--- .../ExtendedQueryStateMachineTests.swift | 28 +- .../ConnectionAction+TestUtils.swift | 8 +- .../PSQLBackendMessage+Equatable.swift | 8 + 9 files changed, 590 insertions(+), 353 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift index e48ac9ff..725f17d8 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift @@ -15,6 +15,25 @@ extension PostgresConnection: PostgresDatabase { switch command { case .query(let query, let binds, let onMetadata, let onRow): + resultFuture = self.underlying.query(query, binds, logger: logger).flatMap { stream in + let fields = stream.rowDescription.map { column in + PostgresMessage.RowDescription.Field( + name: column.name, + tableOID: UInt32(column.tableOID), + columnAttributeNumber: column.columnAttributeNumber, + dataType: PostgresDataType(UInt32(column.dataType.rawValue)), + dataTypeSize: column.dataTypeSize, + dataTypeModifier: column.dataTypeModifier, + formatCode: .init(psqlFormatCode: column.format) + ) + } + + let lookupTable = PostgresRow.LookupTable(rowDescription: .init(fields: fields), resultFormat: [.binary]) + return stream.iterateRowsWithoutBackpressureOption(lookupTable: lookupTable, onRow: onRow).map { _ in + onMetadata(PostgresQueryMetadata(string: stream.commandTag)!) + } + } + case .queryAll(let query, let binds, let onResult): resultFuture = self.underlying.query(query, binds, logger: logger).flatMap { rows in let fields = rows.rowDescription.map { column in PostgresMessage.RowDescription.Field( @@ -29,10 +48,18 @@ extension PostgresConnection: PostgresDatabase { } let lookupTable = PostgresRow.LookupTable(rowDescription: .init(fields: fields), resultFormat: [.binary]) - return rows.iterateRowsWithoutBackpressureOption(lookupTable: lookupTable, onRow: onRow).map { _ in - onMetadata(PostgresQueryMetadata(string: rows.commandTag)!) + return rows.all().map { allrows in + let r = allrows.map { psqlRow -> PostgresRow in + let columns = psqlRow.data.columns.map { + PostgresMessage.DataRow.Column(value: $0) + } + return PostgresRow(dataRow: .init(columns: columns), lookupTable: lookupTable) + } + + onResult(.init(metadata: PostgresQueryMetadata(string: rows.commandTag)!, rows: r)) } } + case .prepareQuery(let request): resultFuture = self.underlying.prepareStatement(request.query, with: request.name, logger: self.logger).map { request.prepared = PreparedQuery(underlying: $0, database: self) @@ -62,6 +89,9 @@ internal enum PostgresCommands: PostgresRequest { binds: [PostgresData], onMetadata: (PostgresQueryMetadata) -> () = { _ in }, onRow: (PostgresRow) throws -> ()) + case queryAll(query: String, + binds: [PostgresData], + onResult: (PostgresQueryResult) -> ()) case prepareQuery(request: PrepareQueryRequest) case executePreparedStatement(query: PreparedQuery, binds: [PostgresData], onRow: (PostgresRow) throws -> ()) @@ -82,18 +112,12 @@ extension PSQLRowStream { func iterateRowsWithoutBackpressureOption(lookupTable: PostgresRow.LookupTable, onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { self.onRow { psqlRow in - let columns = psqlRow.data.columns.map { bytes in - PostgresMessage.DataRow.Column(value: bytes) + let columns = psqlRow.data.columns.map { + PostgresMessage.DataRow.Column(value: $0) } let row = PostgresRow(dataRow: .init(columns: columns), lookupTable: lookupTable) - - do { - try onRow(row) - return self.eventLoop.makeSucceededFuture(Void()) - } catch { - return self.eventLoop.makeFailedFuture(error) - } + try onRow(row) } } diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index dbeafa5d..1af28a3b 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -92,12 +92,9 @@ struct ConnectionStateMachine { // --- streaming actions // actions if query has requested next row but we are waiting for backend - case forwardRow(PSQLBackendMessage.DataRow, to: EventLoopPromise) - case forwardCommandComplete(CircularBuffer, commandTag: String, to: EventLoopPromise) - case forwardStreamError(PSQLError, to: EventLoopPromise, cleanupContext: CleanUpContext?) - // actions if query has not asked for next row but are pushing the final bytes to it - case forwardStreamErrorToCurrentQuery(PSQLError, read: Bool, cleanupContext: CleanUpContext?) - case forwardStreamCompletedToCurrentQuery(CircularBuffer, commandTag: String, read: Bool) + case forwardRows(CircularBuffer) + case forwardStreamComplete(CircularBuffer, commandTag: String) + case forwardStreamError(PSQLError, read: Bool, cleanupContext: CleanUpContext?) // Prepare statement actions case sendParseDescribeSync(name: String, query: String) @@ -172,8 +169,10 @@ struct ConnectionStateMachine { switch self.state { case .initialized: preconditionFailure("How can a connection be closed, if it was never connected.") + case .closed: preconditionFailure("How can a connection be closed, if it is already closed.") + case .authenticated, .sslRequestSent, .sslNegotiated, @@ -185,10 +184,12 @@ struct ConnectionStateMachine { .prepareStatement, .closeCommand: return self.errorHappened(.uncleanShutdown) + case .error, .closing: self.state = .closed self.quiescingState = .notQuiescing return .fireChannelInactive + case .modifying: preconditionFailure("Invalid state") } @@ -199,8 +200,24 @@ struct ConnectionStateMachine { case .sslRequestSent: self.state = .sslNegotiated return .establishSSLConnection - default: + + case .initialized, + .sslNegotiated, + .sslHandlerAdded, + .waitingToStartAuthentication, + .authenticating, + .authenticated, + .readyForQuery, + .extendedQuery, + .prepareStatement, + .closeCommand, + .error, + .closing, + .closed: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.sslSupported)) + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") } } @@ -208,27 +225,77 @@ struct ConnectionStateMachine { switch self.state { case .sslRequestSent: return self.closeConnectionAndCleanup(.sslUnsupported) - default: + + case .initialized, + .sslNegotiated, + .sslHandlerAdded, + .waitingToStartAuthentication, + .authenticating, + .authenticated, + .readyForQuery, + .extendedQuery, + .prepareStatement, + .closeCommand, + .error, + .closing, + .closed: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.sslSupported)) + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") } } mutating func sslHandlerAdded() -> ConnectionAction { - guard case .sslNegotiated = self.state else { - preconditionFailure("Can only add a ssl handler after negotiation") + switch self.state { + case .initialized, + .sslRequestSent, + .sslHandlerAdded, + .waitingToStartAuthentication, + .authenticating, + .authenticated, + .readyForQuery, + .extendedQuery, + .prepareStatement, + .closeCommand, + .error, + .closing, + .closed: + preconditionFailure("Can only add a ssl handler after negotiation: \(self.state)") + + case .sslNegotiated: + self.state = .sslHandlerAdded + return .wait + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") } - - self.state = .sslHandlerAdded - return .wait } mutating func sslEstablished() -> ConnectionAction { - guard case .sslHandlerAdded = self.state else { - preconditionFailure("Can only establish a ssl connection after adding a ssl handler") + switch self.state { + case .initialized, + .sslRequestSent, + .sslNegotiated, + .waitingToStartAuthentication, + .authenticating, + .authenticated, + .readyForQuery, + .extendedQuery, + .prepareStatement, + .closeCommand, + .error, + .closing, + .closed: + preconditionFailure("Can only establish a ssl connection after adding a ssl handler: \(self.state)") + + case .sslHandlerAdded: + self.state = .waitingToStartAuthentication + return .provideAuthenticationContext + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") } - - self.state = .waitingToStartAuthentication - return .provideAuthenticationContext } mutating func authenticationMessageReceived(_ message: PSQLBackendMessage.Authentication) -> ConnectionAction { @@ -518,6 +585,35 @@ struct ConnectionStateMachine { } } + mutating func channelReadComplete() -> ConnectionAction { + switch self.state { + case .initialized, + .sslRequestSent, + .sslNegotiated, + .sslHandlerAdded, + .waitingToStartAuthentication, + .authenticating, + .authenticated, + .readyForQuery, + .prepareStatement, + .closeCommand, + .error, + .closing, + .closed: + return .wait + + case .extendedQuery(var extendedQuery, let connectionContext): + return self.avoidingStateMachineCoW { machine in + let action = extendedQuery.channelReadComplete() + machine.state = .extendedQuery(extendedQuery, connectionContext) + return machine.modify(with: action) + } + + case .modifying: + preconditionFailure("Invalid state") + } + } + mutating func readEventCaught() -> ConnectionAction { switch self.state { case .initialized: @@ -562,7 +658,6 @@ struct ConnectionStateMachine { preconditionFailure("How can we receive a read, if the connection is closed") case .modifying: preconditionFailure("Invalid state") - } } @@ -714,13 +809,13 @@ struct ConnectionStateMachine { preconditionFailure("Unimplemented") } - mutating func consumeNextQueryRow(promise: EventLoopPromise) -> ConnectionAction { + mutating func requestQueryRows() -> ConnectionAction { guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { preconditionFailure("Tried to consume next row, without active query") } return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.consumeNextRow(promise: promise) + let action = queryState.requestQueryRows() machine.state = .extendedQuery(queryState, connectionContext) return machine.modify(with: action) } @@ -783,18 +878,15 @@ struct ConnectionStateMachine { .sendBindExecuteSync, .succeedQuery, .succeedQueryNoRowsComming, - .forwardRow, - .forwardCommandComplete, - .forwardStreamCompletedToCurrentQuery, + .forwardRows, + .forwardStreamComplete, .wait, .read: preconditionFailure("Expecting only failure actions if an error happened") case .failQuery(let queryContext, with: let error): return .failQuery(queryContext, with: error, cleanupContext: cleanupContext) - case .forwardStreamError(let error, to: let promise): - return .forwardStreamError(error, to: promise, cleanupContext: cleanupContext) - case .forwardStreamErrorToCurrentQuery(let error, read: let read): - return .forwardStreamErrorToCurrentQuery(error, read: read, cleanupContext: cleanupContext) + case .forwardStreamError(let error, let read): + return .forwardStreamError(error, read: read, cleanupContext: cleanupContext) } case .prepareStatement(var prepareStateMachine, _): let cleanupContext = self.setErrorAndCreateCleanupContext(error) @@ -1025,18 +1117,13 @@ extension ConnectionStateMachine { return .succeedQuery(requestContext, columns: columns) case .succeedQueryNoRowsComming(let requestContext, let commandTag): return .succeedQueryNoRowsComming(requestContext, commandTag: commandTag) - case .forwardRow(let data, to: let promise): - return .forwardRow(data, to: promise) - case .forwardCommandComplete(let buffer, let commandTag, to: let promise): - return .forwardCommandComplete(buffer, commandTag: commandTag, to: promise) - case .forwardStreamError(let error, to: let promise): - let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) - return .forwardStreamError(error, to: promise, cleanupContext: cleanupContext) - case .forwardStreamErrorToCurrentQuery(let error, let read): + case .forwardRows(let buffer): + return .forwardRows(buffer) + case .forwardStreamComplete(let buffer, let commandTag): + return .forwardStreamComplete(buffer, commandTag: commandTag) + case .forwardStreamError(let error, let read): let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) - return .forwardStreamErrorToCurrentQuery(error, read: read, cleanupContext: cleanupContext) - case .forwardStreamCompletedToCurrentQuery(let buffer, let commandTag, let read): - return .forwardStreamCompletedToCurrentQuery(buffer, commandTag: commandTag, read: read) + return .forwardStreamError(error, read: read, cleanupContext: cleanupContext) case .read: return .read case .wait: @@ -1104,14 +1191,6 @@ extension ConnectionStateMachine { } } -enum StateMachineStreamNextResult { - /// the next row - case row(PSQLBackendMessage.DataRow) - - /// the query has completed, all remaining rows and the command completion tag - case complete(CircularBuffer, commandTag: String) -} - struct SendPrepareStatement { let name: String let query: String diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 36b69f83..4818ca19 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -14,8 +14,7 @@ struct ExtendedQueryStateMachine { /// A state that is used if a noData message was received before. If a row description was received `bufferingRows` is /// used after receiving a `bindComplete` message case bindCompleteReceived(ExtendedQueryContext) - case bufferingRows([PSQLBackendMessage.RowDescription.Column], CircularBuffer, readOnEmpty: Bool) - case waitingForNextRow([PSQLBackendMessage.RowDescription.Column], CircularBuffer, EventLoopPromise) + case streaming([PSQLBackendMessage.RowDescription.Column], RowStreamStateMachine) case commandComplete(commandTag: String) case error(PSQLError) @@ -34,12 +33,9 @@ struct ExtendedQueryStateMachine { // --- streaming actions // actions if query has requested next row but we are waiting for backend - case forwardRow(PSQLBackendMessage.DataRow, to: EventLoopPromise) - case forwardCommandComplete(CircularBuffer, commandTag: String, to: EventLoopPromise) - case forwardStreamError(PSQLError, to: EventLoopPromise) - // actions if query has not asked for next row but are pushing the final bytes to it - case forwardStreamErrorToCurrentQuery(PSQLError, read: Bool) - case forwardStreamCompletedToCurrentQuery(CircularBuffer, commandTag: String, read: Bool) + case forwardRows(CircularBuffer) + case forwardStreamComplete(CircularBuffer, commandTag: String) + case forwardStreamError(PSQLError, read: Bool) case read case wait @@ -137,7 +133,7 @@ struct ExtendedQueryStateMachine { switch self.state { case .rowDescriptionReceived(let context, let columns): return self.avoidingStateMachineCoW { state -> Action in - state = .bufferingRows(columns, CircularBuffer(), readOnEmpty: false) + state = .streaming(columns, .init()) return .succeedQuery(context, columns: columns) } case .noDataMessageReceived(let queryContext): @@ -150,8 +146,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .bindCompleteReceived, - .bufferingRows, - .waitingForNextRow, + .streaming, .commandComplete, .error: return self.setAndFireError(.unexpectedBackendMessage(.bindComplete)) @@ -162,7 +157,7 @@ struct ExtendedQueryStateMachine { mutating func dataRowReceived(_ dataRow: PSQLBackendMessage.DataRow) -> Action { switch self.state { - case .bufferingRows(let columns, var buffer, let readOnEmpty): + case .streaming(let columns, var demandStateMachine): // When receiving a data row, we must ensure that the data row column count // matches the previously received row description column count. guard dataRow.columns.count == columns.count else { @@ -170,24 +165,11 @@ struct ExtendedQueryStateMachine { } return self.avoidingStateMachineCoW { state -> Action in - buffer.append(dataRow) - state = .bufferingRows(columns, buffer, readOnEmpty: readOnEmpty) + demandStateMachine.receivedRow(dataRow) + state = .streaming(columns, demandStateMachine) return .wait } - case .waitingForNextRow(let columns, let buffer, let promise): - // When receiving a data row, we must ensure that the data row column count - // matches the previously received row description column count. - guard dataRow.columns.count == columns.count else { - return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) - } - - return self.avoidingStateMachineCoW { state -> Action in - precondition(buffer.isEmpty, "Expected the buffer to be empty") - state = .bufferingRows(columns, buffer, readOnEmpty: false) - return .forwardRow(dataRow, to: promise) - } - case .initialized, .parseDescribeBindExecuteSyncSent, .parseCompleteReceived, @@ -211,17 +193,10 @@ struct ExtendedQueryStateMachine { return .succeedQueryNoRowsComming(context, commandTag: commandTag) } - case .bufferingRows(_, let buffer, let readOnEmpty): + case .streaming(_, var demandStateMachine): return self.avoidingStateMachineCoW { state -> Action in state = .commandComplete(commandTag: commandTag) - return .forwardStreamCompletedToCurrentQuery(buffer, commandTag: commandTag, read: readOnEmpty) - } - - case .waitingForNextRow(_, let buffer, let promise): - return self.avoidingStateMachineCoW { state -> Action in - precondition(buffer.isEmpty, "Expected the buffer to be empty") - state = .commandComplete(commandTag: commandTag) - return .forwardCommandComplete(buffer, commandTag: commandTag, to: promise) + return .forwardStreamComplete(demandStateMachine.end(), commandTag: commandTag) } case .initialized, @@ -254,9 +229,7 @@ struct ExtendedQueryStateMachine { return self.setAndFireError(error) case .rowDescriptionReceived, .noDataMessageReceived: return self.setAndFireError(error) - case .bufferingRows: - return self.setAndFireError(error) - case .waitingForNextRow: + case .streaming: return self.setAndFireError(error) case .commandComplete: return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) @@ -282,20 +255,18 @@ struct ExtendedQueryStateMachine { // MARK: Customer Actions - mutating func consumeNextRow(promise: EventLoopPromise) -> Action { + mutating func requestQueryRows() -> Action { switch self.state { - case .waitingForNextRow: - preconditionFailure("Too greedy. `consumeNextRow()` only needs to be called once.") - - case .bufferingRows(let columns, var buffer, let readOnEmpty): + case .streaming(let columns, var demandStateMachine): return self.avoidingStateMachineCoW { state -> Action in - guard let row = buffer.popFirst() else { - state = .waitingForNextRow(columns, buffer, promise) - return readOnEmpty ? .read : .wait + let action = demandStateMachine.demandMoreResponseBodyParts() + state = .streaming(columns, demandStateMachine) + switch action { + case .read: + return .read + case .wait: + return .wait } - - state = .bufferingRows(columns, buffer, readOnEmpty: readOnEmpty) - return .forwardRow(row, to: promise) } case .initialized, @@ -316,29 +287,56 @@ struct ExtendedQueryStateMachine { // MARK: Channel actions + mutating func channelReadComplete() -> Action { + switch self.state { + case .initialized, + .commandComplete, + .error, + .parseDescribeBindExecuteSyncSent, + .parseCompleteReceived, + .parameterDescriptionReceived, + .noDataMessageReceived, + .rowDescriptionReceived, + .bindCompleteReceived: + return .wait + + case .streaming(let columns, var demandStateMachine): + return self.avoidingStateMachineCoW { state -> Action in + let rows = demandStateMachine.channelReadComplete() + state = .streaming(columns, demandStateMachine) + switch rows { + case .some(let rows): + return .forwardRows(rows) + case .none: + return .wait + } + } + + case .modifying: + preconditionFailure("Invalid state") + } + } + mutating func readEventCaught() -> Action { switch self.state { - case .parseDescribeBindExecuteSyncSent: - return .read - case .parseCompleteReceived: - return .read - case .parameterDescriptionReceived: - return .read - case .noDataMessageReceived: - return .read - case .rowDescriptionReceived: - return .read - case .bindCompleteReceived: + case .parseDescribeBindExecuteSyncSent, + .parseCompleteReceived, + .parameterDescriptionReceived, + .noDataMessageReceived, + .rowDescriptionReceived, + .bindCompleteReceived: return .read - case .bufferingRows(let columns, let buffer, _): + case .streaming(let columns, var demandStateMachine): return self.avoidingStateMachineCoW { state -> Action in - state = .bufferingRows(columns, buffer, readOnEmpty: true) - return .wait + let action = demandStateMachine.read() + state = .streaming(columns, demandStateMachine) + switch action { + case .wait: + return .wait + case .read: + return .read + } } - case .waitingForNextRow: - // we are in the stream and the consumer has already asked us for more rows, - // therefore we need to read! - return .read case .initialized, .commandComplete, .error: @@ -363,12 +361,11 @@ struct ExtendedQueryStateMachine { .bindCompleteReceived(let context): self.state = .error(error) return .failQuery(context, with: error) - case .bufferingRows(_, _, readOnEmpty: let readOnEmpty): - self.state = .error(error) - return .forwardStreamErrorToCurrentQuery(error, read: readOnEmpty) - case .waitingForNextRow(_, _, let promise): + + case .streaming: self.state = .error(error) - return .forwardStreamError(error, to: promise) + return .forwardStreamError(error, read: false) + case .commandComplete, .error: preconditionFailure(""" This state must not be reached. If the query `.isComplete`, the diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PSQLChannelHandler.swift index e0f71114..c1606497 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PSQLChannelHandler.swift @@ -18,7 +18,12 @@ final class PSQLChannelHandler: ChannelDuplexHandler { self.logger.trace("Connection state changed", metadata: [.connectionState: "\(self.state)"]) } } - private var currentQuery: PSQLRowStream? + + /// A `ChannelHandlerContext` to be used for non channel related events. (for example: More rows needed). + /// + /// The context is captured in `handlerAdded` and released` in `handlerRemoved` + private var handlerContext: ChannelHandlerContext! + private var rowStream: PSQLRowStream? private let authentificationConfiguration: PSQLConnection.Configuration.Authentication? private let configureSSLCallback: ((Channel) throws -> Void)? @@ -52,11 +57,16 @@ final class PSQLChannelHandler: ChannelDuplexHandler { // MARK: Handler lifecycle func handlerAdded(context: ChannelHandlerContext) { + self.handlerContext = context if context.channel.isActive { self.connected(context: context) } } + func handlerRemoved(context: ChannelHandlerContext) { + self.handlerContext = nil + } + // MARK: Channel handler incoming func channelActive(context: ChannelHandlerContext) { @@ -131,6 +141,11 @@ final class PSQLChannelHandler: ChannelDuplexHandler { self.run(action, with: context) } + func channelReadComplete(context: ChannelHandlerContext) { + let action = self.state.channelReadComplete() + self.run(action, with: context) + } + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { self.logger.trace("User inbound event received", metadata: [ .userEvent: "\(event)" @@ -224,38 +239,30 @@ final class PSQLChannelHandler: ChannelDuplexHandler { if let cleanupContext = cleanupContext { self.closeConnectionAndCleanup(cleanupContext, context: context) } - case .forwardRow(let row, to: let promise): - promise.succeed(.row(row)) - case .forwardCommandComplete(let buffer, let commandTag, to: let promise): - promise.succeed(.complete(buffer, commandTag: commandTag)) - self.currentQuery = nil - case .forwardStreamError(let error, to: let promise, let cleanupContext): - promise.fail(error) - self.currentQuery = nil - if let cleanupContext = cleanupContext { - self.closeConnectionAndCleanup(cleanupContext, context: context) - } - case .forwardStreamErrorToCurrentQuery(let error, let read, let cleanupContext): - guard let query = self.currentQuery else { - preconditionFailure("Expected to have an open query at this point") + + case .forwardRows(let rows): + self.rowStream!.receive(rows) + + case .forwardStreamComplete(let buffer, let commandTag): + guard let rowStream = self.rowStream else { + preconditionFailure("Expected to have a row stream here.") } - query.finalForward(.failure(error)) - self.currentQuery = nil - if read { - context.read() + self.rowStream = nil + if buffer.count > 0 { + rowStream.receive(buffer) } + rowStream.receive(completion: .success(commandTag)) + + + case .forwardStreamError(let error, let read, let cleanupContext): + self.rowStream!.receive(completion: .failure(error)) + self.rowStream = nil if let cleanupContext = cleanupContext { self.closeConnectionAndCleanup(cleanupContext, context: context) - } - case .forwardStreamCompletedToCurrentQuery(let buffer, commandTag: let commandTag, let read): - guard let query = self.currentQuery else { - preconditionFailure("Expected to have an open query at this point") - } - query.finalForward(.success((buffer, commandTag))) - self.currentQuery = nil - if read { + } else if read { context.read() } + case .provideAuthenticationContext: context.fireUserInboundEventTriggered(PSQLEvent.readyForStartup) @@ -363,7 +370,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { query: String, context: ChannelHandlerContext) { - precondition(self.currentQuery == nil, "Expected to not have an open query at this point") + precondition(self.rowStream == nil, "Expected to not have an open stream at this point") let parse = PSQLFrontendMessage.Parse( preparedStatementName: statementName, query: query, @@ -395,7 +402,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { query: String, binds: [PSQLEncodable], context: ChannelHandlerContext) { - precondition(self.currentQuery == nil, "Expected to not have an open query at this point") + precondition(self.rowStream == nil, "Expected to not have an open stream at this point") let unnamedStatementName = "" let parse = PSQLFrontendMessage.Parse( preparedStatementName: unnamedStatementName, @@ -406,11 +413,11 @@ final class PSQLChannelHandler: ChannelDuplexHandler { preparedStatementName: unnamedStatementName, parameters: binds) - context.write(.parse(parse), promise: nil) - context.write(.describe(.preparedStatement("")), promise: nil) - context.write(.bind(bind), promise: nil) - context.write(.execute(.init(portalName: "")), promise: nil) - context.write(.sync, promise: nil) + context.write(wrapOutboundOut(.parse(parse)), promise: nil) + context.write(wrapOutboundOut(.describe(.preparedStatement(""))), promise: nil) + context.write(wrapOutboundOut(.bind(bind)), promise: nil) + context.write(wrapOutboundOut(.execute(.init(portalName: ""))), promise: nil) + context.write(wrapOutboundOut(.sync), promise: nil) context.flush() } @@ -419,29 +426,13 @@ final class PSQLChannelHandler: ChannelDuplexHandler { columns: [PSQLBackendMessage.RowDescription.Column], context: ChannelHandlerContext) { - let eventLoop = context.channel.eventLoop - func consumeNextRow() -> EventLoopFuture { - let promise = eventLoop.makePromise(of: StateMachineStreamNextResult.self) - let action = self.state.consumeNextQueryRow(promise: promise) - self.run(action, with: context) - return promise.futureResult - } let rows = PSQLRowStream( rowDescription: columns, queryContext: queryContext, eventLoop: context.channel.eventLoop, - cancel: { - let action = self.state.cancelQueryStream() - self.run(action, with: context) - }, next: { - guard eventLoop.inEventLoop else { - return eventLoop.flatSubmit { consumeNextRow() } - } - - return consumeNextRow() - }) + rowSource: .stream(self)) - self.currentQuery = rows + self.rowStream = rows queryContext.promise.succeed(rows) } @@ -450,17 +441,12 @@ final class PSQLChannelHandler: ChannelDuplexHandler { commandTag: String, context: ChannelHandlerContext) { - let eventLoop = context.channel.eventLoop let rows = PSQLRowStream( rowDescription: [], queryContext: queryContext, eventLoop: context.channel.eventLoop, - cancel: { - // ignore... - }, next: { - let emptyBuffer = CircularBuffer(initialCapacity: 0) - return eventLoop.makeSucceededFuture(.complete(emptyBuffer, commandTag: commandTag)) - }) + rowSource: .noRows(.success(commandTag)) + ) queryContext.promise.succeed(rows) } @@ -489,6 +475,23 @@ final class PSQLChannelHandler: ChannelDuplexHandler { } } +extension PSQLChannelHandler: PSQLRowsDataSource { + func request(for stream: PSQLRowStream) { + guard self.rowStream === stream else { + return + } + let action = self.state.requestQueryRows() + self.run(action, with: self.handlerContext!) + } + + func cancel(for stream: PSQLRowStream) { + guard self.rowStream === stream else { + return + } + // we ignore this right now :) + } +} + extension ChannelHandlerContext { func write(_ psqlMessage: PSQLFrontendMessage, promise: EventLoopPromise? = nil) { self.write(NIOAny(psqlMessage), promise: promise) @@ -517,4 +520,3 @@ extension AuthContext { replication: .false) } } - diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index 0f28a527..768255fb 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -3,18 +3,25 @@ import Logging final class PSQLRowStream { + enum RowSource { + case stream(PSQLRowsDataSource) + case noRows(Result) + } + let eventLoop: EventLoop let logger: Logger private enum UpstreamState { - case streaming(next: () -> EventLoopFuture, cancel: () -> ()) - case finished(remaining: CircularBuffer, commandTag: String) + case streaming(buffer: CircularBuffer, dataSource: PSQLRowsDataSource) + case finished(buffer: CircularBuffer, commandTag: String) case failure(Error) case consumed(Result) + case modifying } private enum DownstreamState { - case waitingForNext + case iteratingRows(onRow: (PSQLRow) throws -> (), EventLoopPromise) + case waitingForAll(EventLoopPromise<[PSQLRow]>) case consuming } @@ -27,11 +34,19 @@ final class PSQLRowStream { init(rowDescription: [PSQLBackendMessage.RowDescription.Column], queryContext: ExtendedQueryContext, eventLoop: EventLoop, - cancel: @escaping () -> (), - next: @escaping () -> EventLoopFuture) + rowSource: RowSource) { - self.upstreamState = .streaming(next: next, cancel: cancel) + let buffer = CircularBuffer() + self.downstreamState = .consuming + switch rowSource { + case .stream(let dataSource): + self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource) + case .noRows(.success(let commandTag)): + self.upstreamState = .finished(buffer: .init(), commandTag: commandTag) + case .noRows(.failure(let error)): + self.upstreamState = .failure(error) + } self.eventLoop = eventLoop self.logger = queryContext.logger @@ -45,56 +60,123 @@ final class PSQLRowStream { } self.lookupTable = lookup } - - func next() -> EventLoopFuture { - guard self.eventLoop.inEventLoop else { + + func all() -> EventLoopFuture<[PSQLRow]> { + if self.eventLoop.inEventLoop { + return self.all0() + } else { return self.eventLoop.flatSubmit { - self.next() + self.all0() } } + } + + private func all0() -> EventLoopFuture<[PSQLRow]> { + self.eventLoop.preconditionInEventLoop() - assert(self.downstreamState == .consuming) + guard case .consuming = self.downstreamState else { + preconditionFailure("Invalid state") + } switch self.upstreamState { - case .streaming(let upstreamNext, _): - return upstreamNext().map { payload -> PSQLRow? in - self.downstreamState = .consuming - switch payload { - case .row(let data): - return PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) - case .complete(var buffer, let commandTag): - if let data = buffer.popFirst() { - self.upstreamState = .finished(remaining: buffer, commandTag: commandTag) - return PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) - } - - self.upstreamState = .consumed(.success(commandTag)) - return nil - } - }.flatMapErrorThrowing { error in - // if we have an error upstream that, we pass through here, we need to set - // our internal state - self.upstreamState = .consumed(.failure(error)) - throw error - } + case .streaming(_, let dataSource): + dataSource.request(for: self) + let promise = self.eventLoop.makePromise(of: [PSQLRow].self) + self.downstreamState = .waitingForAll(promise) + return promise.futureResult - case .finished(remaining: var buffer, commandTag: let commandTag): - self.downstreamState = .consuming - if let data = buffer.popFirst() { - self.upstreamState = .finished(remaining: buffer, commandTag: commandTag) - let row = PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) - return self.eventLoop.makeSucceededFuture(row) + case .finished(let buffer, let commandTag): + self.upstreamState = .modifying + + let rows = buffer.map { + PSQLRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) } + self.downstreamState = .consuming self.upstreamState = .consumed(.success(commandTag)) - return self.eventLoop.makeSucceededFuture(nil) + return self.eventLoop.makeSucceededFuture(rows) + + case .consumed: + preconditionFailure("We already signaled, that the stream has completed, why are we asked again?") + + case .modifying: + preconditionFailure("Invalid state") case .failure(let error): self.upstreamState = .consumed(.failure(error)) return self.eventLoop.makeFailedFuture(error) + } + } + + func onRow(_ onRow: @escaping (PSQLRow) throws -> ()) -> EventLoopFuture { + if self.eventLoop.inEventLoop { + return self.onRow0(onRow) + } else { + return self.eventLoop.flatSubmit { + self.onRow0(onRow) + } + } + } + + private func onRow0(_ onRow: @escaping (PSQLRow) throws -> ()) -> EventLoopFuture { + self.eventLoop.preconditionInEventLoop() + + switch self.upstreamState { + case .streaming(var buffer, let dataSource): + let promise = self.eventLoop.makePromise(of: Void.self) + do { + for data in buffer { + let row = PSQLRow( + data: data, + lookupTable: self.lookupTable, + columns: self.rowDescription, + jsonDecoder: self.jsonDecoder + ) + try onRow(row) + } + + buffer.removeAll() + self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource) + self.downstreamState = .iteratingRows(onRow: onRow, promise) + // immediately request more + dataSource.request(for: self) + } catch { + self.upstreamState = .failure(error) + dataSource.cancel(for: self) + promise.fail(error) + } + + return promise.futureResult + + case .finished(let buffer, let commandTag): + do { + for data in buffer { + let row = PSQLRow( + data: data, + lookupTable: self.lookupTable, + columns: self.rowDescription, + jsonDecoder: self.jsonDecoder + ) + try onRow(row) + } + + self.upstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consuming + return self.eventLoop.makeSucceededVoidFuture() + } catch { + self.upstreamState = .consumed(.failure(error)) + return self.eventLoop.makeFailedFuture(error) + } case .consumed: preconditionFailure("We already signaled, that the stream has completed, why are we asked again?") + + case .modifying: + preconditionFailure("Invalid state") + + case .failure(let error): + self.upstreamState = .consumed(.failure(error)) + return self.eventLoop.makeFailedFuture(error) } } @@ -104,40 +186,106 @@ final class PSQLRowStream { ]) } - internal func finalForward(_ finalForward: Result<(CircularBuffer, commandTag: String), PSQLError>?) { - switch finalForward { - case .some(.success((let buffer, commandTag: let commandTag))): - guard case .streaming = self.upstreamState else { - preconditionFailure("Expected to be streaming up until now") + internal func receive(_ newRows: CircularBuffer) { + precondition(!newRows.isEmpty, "Expected to get rows!") + self.eventLoop.preconditionInEventLoop() + self.logger.trace("Row stream received rows", metadata: [ + "row_count": "\(newRows.count)" + ]) + + guard case .streaming(var buffer, let dataSource) = self.upstreamState else { + preconditionFailure("Invalid state") + } + + switch self.downstreamState { + case .iteratingRows(let onRow, let promise): + precondition(buffer.isEmpty) + do { + for data in newRows { + let row = PSQLRow( + data: data, + lookupTable: self.lookupTable, + columns: self.rowDescription, + jsonDecoder: self.jsonDecoder + ) + try onRow(row) + } + // immediately request more + dataSource.request(for: self) + } catch { + dataSource.cancel(for: self) + self.upstreamState = .failure(error) + promise.fail(error) + return + } + case .waitingForAll: + self.upstreamState = .modifying + buffer.append(contentsOf: newRows) + self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource) + + // immediately request more + dataSource.request(for: self) + + case .consuming: + // this might happen, if the query has finished while the user is consuming data + // we don't need to ask for more since the user is consuming anyway + self.upstreamState = .modifying + buffer.append(contentsOf: newRows) + self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource) + } + } + + internal func receive(completion result: Result) { + self.eventLoop.preconditionInEventLoop() + + guard case .streaming(let oldBuffer, _) = self.upstreamState else { + preconditionFailure("Invalid state") + } + + switch self.downstreamState { + case .iteratingRows(_, let promise): + precondition(oldBuffer.isEmpty) + self.downstreamState = .consuming + self.upstreamState = .consumed(result) + switch result { + case .success: + promise.succeed(()) + case .failure(let error): + promise.fail(error) } - self.upstreamState = .finished(remaining: buffer, commandTag: commandTag) - case .some(.failure(let error)): - guard case .streaming = self.upstreamState else { - preconditionFailure("Expected to be streaming up until now") + + + case .consuming: + switch result { + case .success(let commandTag): + self.upstreamState = .finished(buffer: oldBuffer, commandTag: commandTag) + case .failure(let error): + self.upstreamState = .failure(error) } - self.upstreamState = .failure(error) - case .none: - switch self.upstreamState { - case .consumed: - break - case .finished: - break - case .failure: - preconditionFailure("Invalid state") - case .streaming: - preconditionFailure("Invalid state") + + case .waitingForAll(let promise): + switch result { + case .failure(let error): + self.upstreamState = .consumed(.failure(error)) + promise.fail(error) + case .success(let commandTag): + let rows = oldBuffer.map { + PSQLRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) + } + self.upstreamState = .consumed(.success(commandTag)) + promise.succeed(rows) } } } func cancel() { - guard case .streaming(_, let cancel) = self.upstreamState else { + guard case .streaming(_, let dataSource) = self.upstreamState else { // We don't need to cancel any upstream resource. All needed data is already - // included in this + // included in this return } - cancel() + dataSource.cancel(for: self) } var commandTag: String { @@ -146,32 +294,11 @@ final class PSQLRowStream { } return commandTag } - - func onRow(_ onRow: @escaping (PSQLRow) -> EventLoopFuture) -> EventLoopFuture { - let promise = self.eventLoop.makePromise(of: Void.self) - - func consumeNext(promise: EventLoopPromise) { - self.next().whenComplete { result in - switch result { - case .success(.some(let row)): - onRow(row).whenComplete { result in - switch result { - case .success: - consumeNext(promise: promise) - case .failure(let error): - promise.fail(error) - } - } - case .success(.none): - promise.succeed(Void()) - case .failure(let error): - promise.fail(error) - } - } - } - - consumeNext(promise: promise) - - return promise.futureResult - } +} + +protocol PSQLRowsDataSource { + + func request(for stream: PSQLRowStream) + func cancel(for stream: PSQLRowStream) + } diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index c7112a5b..011d8c70 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -56,12 +56,11 @@ final class IntegrationTests: XCTestCase { var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query("SELECT version()", logger: .psqlTest).wait()) - var row: PSQLRow? - XCTAssertNoThrow(row = try stream?.next().wait()) + var rows: [PSQLRow]? + XCTAssertNoThrow(rows = try XCTUnwrap(stream).all().wait()) var version: String? - XCTAssertNoThrow(version = try row?.decode(column: 0, as: String.self)) + XCTAssertNoThrow(version = try rows?.first?.decode(column: 0, as: String.self)) XCTAssertEqual(version?.contains("PostgreSQL"), true) - XCTAssertNil(try stream?.next().wait()) } func testQuery10kItems() { @@ -76,27 +75,20 @@ final class IntegrationTests: XCTestCase { var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query("SELECT generate_series(1, 10000);", logger: .psqlTest).wait()) - var expected: Int64 = 1 + var received: Int64 = 0 XCTAssertNoThrow(try stream?.onRow { row in - let promise = eventLoop.makePromise(of: Void.self) - func workaround() { var number: Int64? XCTAssertNoThrow(number = try row.decode(column: 0, as: Int64.self)) - XCTAssertEqual(number, expected) - expected += 1 - } - - eventLoop.execute { - workaround() - promise.succeed(()) + received += 1 + XCTAssertEqual(number, received) } - return promise.futureResult + workaround() }.wait()) - XCTAssertEqual(expected, 10001) + XCTAssertEqual(received, 10000) } func test1kRoundTrips() { @@ -111,12 +103,11 @@ final class IntegrationTests: XCTestCase { for _ in 0..<1_000 { var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query("SELECT version()", logger: .psqlTest).wait()) - var row: PSQLRow? - XCTAssertNoThrow(row = try stream?.next().wait()) + var rows: [PSQLRow]? + XCTAssertNoThrow(rows = try XCTUnwrap(stream).all().wait()) var version: String? - XCTAssertNoThrow(version = try row?.decode(column: 0, as: String.self)) + XCTAssertNoThrow(version = try rows?.first?.decode(column: 0, as: String.self)) XCTAssertEqual(version?.contains("PostgreSQL"), true) - XCTAssertNil(try stream?.next().wait()) } } @@ -131,12 +122,11 @@ final class IntegrationTests: XCTestCase { var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query("SELECT $1::TEXT as foo", ["hello"], logger: .psqlTest).wait()) - var row: PSQLRow? - XCTAssertNoThrow(row = try stream?.next().wait()) + var rows: [PSQLRow]? + XCTAssertNoThrow(rows = try XCTUnwrap(stream).all().wait()) var foo: String? - XCTAssertNoThrow(foo = try row?.decode(column: 0, as: String.self)) + XCTAssertNoThrow(foo = try rows?.first?.decode(column: 0, as: String.self)) XCTAssertEqual(foo, "hello") - XCTAssertNil(try stream?.next().wait()) } func testDecodeIntegers() { @@ -162,8 +152,10 @@ final class IntegrationTests: XCTestCase { 9223372036854775807::BIGINT as bigint_max """, logger: .psqlTest).wait()) - var row: PSQLRow? - XCTAssertNoThrow(row = try stream?.next().wait()) + var rows: [PSQLRow]? + XCTAssertNoThrow(rows = try stream?.all().wait()) + XCTAssertEqual(rows?.count, 1) + let row = rows?.first XCTAssertEqual(try row?.decode(column: "smallint", as: Int16.self), 1) XCTAssertEqual(try row?.decode(column: "smallint_min", as: Int16.self), -32_767) @@ -174,8 +166,6 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(try row?.decode(column: "bigint", as: Int64.self), 1) XCTAssertEqual(try row?.decode(column: "bigint_min", as: Int64.self), -9_223_372_036_854_775_807) XCTAssertEqual(try row?.decode(column: "bigint_max", as: Int64.self), 9_223_372_036_854_775_807) - - XCTAssertNil(try stream?.next().wait()) } func testEncodeAndDecodeIntArray() { @@ -191,11 +181,10 @@ final class IntegrationTests: XCTestCase { let array: [Int64] = [1, 2, 3] XCTAssertNoThrow(stream = try conn?.query("SELECT $1::int8[] as array", [array], logger: .psqlTest).wait()) - var row: PSQLRow? - XCTAssertNoThrow(row = try stream?.next().wait()) - - XCTAssertEqual(try row?.decode(column: "array", as: [Int64].self), array) - XCTAssertNil(try stream?.next().wait()) + var rows: [PSQLRow]? + XCTAssertNoThrow(rows = try stream?.all().wait()) + XCTAssertEqual(rows?.count, 1) + XCTAssertEqual(try rows?.first?.decode(column: "array", as: [Int64].self), array) } func testDecodeEmptyIntegerArray() { @@ -210,11 +199,10 @@ final class IntegrationTests: XCTestCase { var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query("SELECT '{}'::int[] as array", logger: .psqlTest).wait()) - var row: PSQLRow? - XCTAssertNoThrow(row = try stream?.next().wait()) - - XCTAssertEqual(try row?.decode(column: "array", as: [Int64].self), []) - XCTAssertNil(try stream?.next().wait()) + var rows: [PSQLRow]? + XCTAssertNoThrow(rows = try stream?.all().wait()) + XCTAssertEqual(rows?.count, 1) + XCTAssertEqual(try rows?.first?.decode(column: "array", as: [Int64].self), []) } func testDoubleArraySerialization() { @@ -230,11 +218,10 @@ final class IntegrationTests: XCTestCase { let doubles: [Double] = [3.14, 42] XCTAssertNoThrow(stream = try conn?.query("SELECT $1::double precision[] as doubles", [doubles], logger: .psqlTest).wait()) - var row: PSQLRow? - XCTAssertNoThrow(row = try stream?.next().wait()) - - XCTAssertEqual(try row?.decode(column: "doubles", as: [Double].self), doubles) - XCTAssertNil(try stream?.next().wait()) + var rows: [PSQLRow]? + XCTAssertNoThrow(rows = try stream?.all().wait()) + XCTAssertEqual(rows?.count, 1) + XCTAssertEqual(try rows?.first?.decode(column: "doubles", as: [Double].self), doubles) } func testDecodeDates() { @@ -254,14 +241,14 @@ final class IntegrationTests: XCTestCase { '2016-01-18 01:02:03 +0042'::TIMESTAMPTZ as timestamptz """, logger: .psqlTest).wait()) - var row: PSQLRow? - XCTAssertNoThrow(row = try stream?.next().wait()) + var rows: [PSQLRow]? + XCTAssertNoThrow(rows = try stream?.all().wait()) + XCTAssertEqual(rows?.count, 1) + let row = rows?.first XCTAssertEqual(try row?.decode(column: "date", as: Date.self).description, "2016-01-18 00:00:00 +0000") XCTAssertEqual(try row?.decode(column: "timestamp", as: Date.self).description, "2016-01-18 01:02:03 +0000") XCTAssertEqual(try row?.decode(column: "timestamptz", as: Date.self).description, "2016-01-18 00:20:03 +0000") - - XCTAssertNil(try stream?.next().wait()) } func testDecodeUUID() { @@ -278,12 +265,11 @@ final class IntegrationTests: XCTestCase { SELECT '2c68f645-9ca6-468b-b193-ee97f241c2f8'::UUID as uuid """, logger: .psqlTest).wait()) - var row: PSQLRow? - XCTAssertNoThrow(row = try stream?.next().wait()) - - XCTAssertEqual(try row?.decode(column: "uuid", as: UUID.self), UUID(uuidString: "2c68f645-9ca6-468b-b193-ee97f241c2f8")) + var rows: [PSQLRow]? + XCTAssertNoThrow(rows = try stream?.all().wait()) + XCTAssertEqual(rows?.count, 1) - XCTAssertNil(try stream?.next().wait()) + XCTAssertEqual(try rows?.first?.decode(column: "uuid", as: UUID.self), UUID(uuidString: "2c68f645-9ca6-468b-b193-ee97f241c2f8")) } func testRoundTripJSONB() { @@ -306,14 +292,13 @@ final class IntegrationTests: XCTestCase { select $1::jsonb as jsonb """, [Object(foo: 1, bar: 2)], logger: .psqlTest).wait()) - var row: PSQLRow? - XCTAssertNoThrow(row = try stream?.next().wait()) + var rows: [PSQLRow]? + XCTAssertNoThrow(rows = try stream?.all().wait()) + XCTAssertEqual(rows?.count, 1) var result: Object? - XCTAssertNoThrow(result = try row?.decode(column: "jsonb", as: Object.self)) + XCTAssertNoThrow(result = try rows?.first?.decode(column: "jsonb", as: Object.self)) XCTAssertEqual(result?.foo, 1) XCTAssertEqual(result?.bar, 2) - - XCTAssertNil(try stream?.next().wait()) } do { @@ -322,14 +307,13 @@ final class IntegrationTests: XCTestCase { select $1::json as json """, [Object(foo: 1, bar: 2)], logger: .psqlTest).wait()) - var row: PSQLRow? - XCTAssertNoThrow(row = try stream?.next().wait()) + var rows: [PSQLRow]? + XCTAssertNoThrow(rows = try stream?.all().wait()) + XCTAssertEqual(rows?.count, 1) var result: Object? - XCTAssertNoThrow(result = try row?.decode(column: "json", as: Object.self)) + XCTAssertNoThrow(result = try rows?.first?.decode(column: "json", as: Object.self)) XCTAssertEqual(result?.foo, 1) XCTAssertEqual(result?.bar, 2) - - XCTAssertNil(try stream?.next().wait()) } } } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index ea457bd5..e1076a6e 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -50,15 +50,31 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(queryContext, columns: expected)) - let rowContent = ByteBuffer(string: "test") - XCTAssertEqual(state.dataRowReceived(.init(columns: [rowContent])), .wait) + let row1: PSQLBackendMessage.DataRow = [ByteBuffer(string: "test1")] + XCTAssertEqual(state.dataRowReceived(row1), .wait) + XCTAssertEqual(state.channelReadComplete(), .forwardRows([row1])) XCTAssertEqual(state.readEventCaught(), .wait) + XCTAssertEqual(state.requestQueryRows(), .read) - let rowPromise = EmbeddedEventLoop().makePromise(of: StateMachineStreamNextResult.self) - rowPromise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. - XCTAssertEqual(state.consumeNextQueryRow(promise: rowPromise), .forwardRow(.init(columns: [rowContent]), to: rowPromise)) + let row2: PSQLBackendMessage.DataRow = [ByteBuffer(string: "test2")] + let row3: PSQLBackendMessage.DataRow = [ByteBuffer(string: "test3")] + let row4: PSQLBackendMessage.DataRow = [ByteBuffer(string: "test4")] + XCTAssertEqual(state.dataRowReceived(row2), .wait) + XCTAssertEqual(state.dataRowReceived(row3), .wait) + XCTAssertEqual(state.dataRowReceived(row4), .wait) + XCTAssertEqual(state.channelReadComplete(), .forwardRows([row2, row3, row4])) + XCTAssertEqual(state.requestQueryRows(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) - XCTAssertEqual(state.commandCompletedReceived("SELECT 1"), .forwardStreamCompletedToCurrentQuery(CircularBuffer(), commandTag: "SELECT 1", read: true)) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + let row5: PSQLBackendMessage.DataRow = [ByteBuffer(string: "test5")] + let row6: PSQLBackendMessage.DataRow = [ByteBuffer(string: "test6")] + XCTAssertEqual(state.dataRowReceived(row5), .wait) + XCTAssertEqual(state.dataRowReceived(row6), .wait) + + XCTAssertEqual(state.commandCompletedReceived("SELECT 2"), .forwardStreamComplete([row5, row6], commandTag: "SELECT 2")) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) } diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift index dc7aaa7b..c88d112f 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -67,10 +67,10 @@ extension ConnectionStateMachine.ConnectionAction: Equatable { return lhsContext === rhsContext && lhsRowDescription == rhsRowDescription case (.failQuery(let lhsContext, let lhsError, let lhsCleanupContext), .failQuery(let rhsContext, let rhsError, let rhsCleanupContext)): return lhsContext === rhsContext && lhsError == rhsError && lhsCleanupContext == rhsCleanupContext - case (.forwardRow(let lhsColumns, let lhsPromise), .forwardRow(let rhsColumns, let rhsPromise)): - return lhsColumns == rhsColumns && lhsPromise.futureResult === rhsPromise.futureResult - case (.forwardStreamCompletedToCurrentQuery(let lhsBuffer, let lhsCommandTag, let lhsRead), .forwardStreamCompletedToCurrentQuery(let rhsBuffer, let rhsCommandTag, let rhsRead)): - return lhsBuffer == rhsBuffer && lhsCommandTag == rhsCommandTag && lhsRead == rhsRead + case (.forwardRows(let lhsRows), .forwardRows(let rhsRows)): + return lhsRows == rhsRows + case (.forwardStreamComplete(let lhsBuffer, let lhsCommandTag), .forwardStreamComplete(let rhsBuffer, let rhsCommandTag)): + return lhsBuffer == rhsBuffer && lhsCommandTag == rhsCommandTag case (.sendParseDescribeSync(let lhsName, let lhsQuery), .sendParseDescribeSync(let rhsName, let rhsQuery)): return lhsName == rhsName && lhsQuery == rhsQuery case (.succeedPreparedStatementCreation(let lhsContext, let lhsRowDescription), .succeedPreparedStatementCreation(let rhsContext, let rhsRowDescription)): diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift index 436c7aa9..8434e761 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift @@ -47,3 +47,11 @@ extension PSQLBackendMessage: Equatable { } } } + +extension PSQLBackendMessage.DataRow: ExpressibleByArrayLiteral { + public typealias ArrayLiteralElement = ByteBuffer + + public init(arrayLiteral elements: ByteBuffer...) { + self.init(columns: elements) + } +} From 6611ee128c9c45dba4e608a92ea523eceed6cdba Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 23 Sep 2021 10:00:16 +0200 Subject: [PATCH 023/292] Depend on Swift Crypto `"1.0.0" ..< "3.0.0"` (#183) --- Package.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Package.swift b/Package.swift index c46089e0..97966ead 100644 --- a/Package.swift +++ b/Package.swift @@ -15,7 +15,7 @@ let package = Package( dependencies: [ .package(url: "https://github.com/apple/swift-nio.git", from: "2.32.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.1"), - .package(url: "https://github.com/apple/swift-crypto.git", from: "1.0.0"), + .package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "3.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.0.0"), .package(url: "https://github.com/apple/swift-log.git", from: "1.4.0"), ], From 28ab2df3674636d492d61fca657c11f470983580 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 23 Sep 2021 10:50:57 +0200 Subject: [PATCH 024/292] Move Decoding into PSQLChannelHandler (#182) --- .../PostgresNIO/New/PSQLChannelHandler.swift | 105 ++++++++++-------- Sources/PostgresNIO/New/PSQLConnection.swift | 5 +- ...wift => ReverseByteToMessageHandler.swift} | 0 .../ReverseMessageToByteHandler.swift | 32 ++++++ .../New/PSQLChannelHandlerTests.swift | 26 ++++- 5 files changed, 111 insertions(+), 57 deletions(-) rename Tests/PostgresNIOTests/New/Extensions/{ReverseChannelDecoder.swift => ReverseByteToMessageHandler.swift} (100%) create mode 100644 Tests/PostgresNIOTests/New/Extensions/ReverseMessageToByteHandler.swift diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PSQLChannelHandler.swift index c1606497..4a7e7808 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PSQLChannelHandler.swift @@ -8,8 +8,8 @@ protocol PSQLChannelHandlerNotificationDelegate: AnyObject { } final class PSQLChannelHandler: ChannelDuplexHandler { - typealias InboundIn = PSQLBackendMessage typealias OutboundIn = PSQLTask + typealias InboundIn = ByteBuffer typealias OutboundOut = PSQLFrontendMessage private let logger: Logger @@ -24,6 +24,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { /// The context is captured in `handlerAdded` and released` in `handlerRemoved` private var handlerContext: ChannelHandlerContext! private var rowStream: PSQLRowStream? + private var decoder: NIOSingleStepByteToMessageProcessor private let authentificationConfiguration: PSQLConnection.Configuration.Authentication? private let configureSSLCallback: ((Channel) throws -> Void)? @@ -38,6 +39,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { self.authentificationConfiguration = authentification self.configureSSLCallback = configureSSLCallback self.logger = logger + self.decoder = NIOSingleStepByteToMessageProcessor(PSQLBackendMessageDecoder()) } #if DEBUG @@ -51,6 +53,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { self.authentificationConfiguration = authentification self.configureSSLCallback = configureSSLCallback self.logger = logger + self.decoder = NIOSingleStepByteToMessageProcessor(PSQLBackendMessageDecoder()) } #endif @@ -91,54 +94,62 @@ final class PSQLChannelHandler: ChannelDuplexHandler { } func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let incomingMessage = self.unwrapInboundIn(data) + let buffer = self.unwrapInboundIn(data) - self.logger.trace("Backend message received", metadata: [.message: "\(incomingMessage)"]) - - let action: ConnectionStateMachine.ConnectionAction - - switch incomingMessage { - case .authentication(let authentication): - action = self.state.authenticationMessageReceived(authentication) - case .backendKeyData(let keyData): - action = self.state.backendKeyDataReceived(keyData) - case .bindComplete: - action = self.state.bindCompleteReceived() - case .closeComplete: - action = self.state.closeCompletedReceived() - case .commandComplete(let commandTag): - action = self.state.commandCompletedReceived(commandTag) - case .dataRow(let dataRow): - action = self.state.dataRowReceived(dataRow) - case .emptyQueryResponse: - action = self.state.emptyQueryResponseReceived() - case .error(let errorResponse): - action = self.state.errorReceived(errorResponse) - case .noData: - action = self.state.noDataReceived() - case .notice(let noticeResponse): - action = self.state.noticeReceived(noticeResponse) - case .notification(let notification): - action = self.state.notificationReceived(notification) - case .parameterDescription(let parameterDescription): - action = self.state.parameterDescriptionReceived(parameterDescription) - case .parameterStatus(let parameterStatus): - action = self.state.parameterStatusReceived(parameterStatus) - case .parseComplete: - action = self.state.parseCompleteReceived() - case .portalSuspended: - action = self.state.portalSuspendedReceived() - case .readyForQuery(let transactionState): - action = self.state.readyForQueryReceived(transactionState) - case .rowDescription(let rowDescription): - action = self.state.rowDescriptionReceived(rowDescription) - case .sslSupported: - action = self.state.sslSupportedReceived() - case .sslUnsupported: - action = self.state.sslUnsupportedReceived() + do { + try self.decoder.process(buffer: buffer) { message in + self.logger.trace("Backend message received", metadata: [.message: "\(message)"]) + let action: ConnectionStateMachine.ConnectionAction + + switch message { + case .authentication(let authentication): + action = self.state.authenticationMessageReceived(authentication) + case .backendKeyData(let keyData): + action = self.state.backendKeyDataReceived(keyData) + case .bindComplete: + action = self.state.bindCompleteReceived() + case .closeComplete: + action = self.state.closeCompletedReceived() + case .commandComplete(let commandTag): + action = self.state.commandCompletedReceived(commandTag) + case .dataRow(let dataRow): + action = self.state.dataRowReceived(dataRow) + case .emptyQueryResponse: + action = self.state.emptyQueryResponseReceived() + case .error(let errorResponse): + action = self.state.errorReceived(errorResponse) + case .noData: + action = self.state.noDataReceived() + case .notice(let noticeResponse): + action = self.state.noticeReceived(noticeResponse) + case .notification(let notification): + action = self.state.notificationReceived(notification) + case .parameterDescription(let parameterDescription): + action = self.state.parameterDescriptionReceived(parameterDescription) + case .parameterStatus(let parameterStatus): + action = self.state.parameterStatusReceived(parameterStatus) + case .parseComplete: + action = self.state.parseCompleteReceived() + case .portalSuspended: + action = self.state.portalSuspendedReceived() + case .readyForQuery(let transactionState): + action = self.state.readyForQueryReceived(transactionState) + case .rowDescription(let rowDescription): + action = self.state.rowDescriptionReceived(rowDescription) + case .sslSupported: + action = self.state.sslSupportedReceived() + case .sslUnsupported: + action = self.state.sslUnsupportedReceived() + } + + self.run(action, with: context) + } + } catch let error as PSQLDecodingError { + let action = self.state.errorHappened(.decoding(error)) + self.run(action, with: context) + } catch { + preconditionFailure("Expected to only get PSQLDecodingErrors from the PSQLBackendMessageDecoder.") } - - self.run(action, with: context) } func channelReadComplete(context: ChannelHandlerContext) { diff --git a/Sources/PostgresNIO/New/PSQLConnection.swift b/Sources/PostgresNIO/New/PSQLConnection.swift index 54c58aee..ad5620aa 100644 --- a/Sources/PostgresNIO/New/PSQLConnection.swift +++ b/Sources/PostgresNIO/New/PSQLConnection.swift @@ -214,8 +214,6 @@ final class PSQLConnection { }.flatMap { address -> EventLoopFuture in let bootstrap = ClientBootstrap(group: eventLoop) .channelInitializer { channel in - let decoder = ByteToMessageHandler(PSQLBackendMessageDecoder()) - var configureSSLCallback: ((Channel) throws -> ())? = nil if let tlsConfiguration = configuration.tlsConfiguration { configureSSLCallback = { channel in @@ -225,12 +223,11 @@ final class PSQLConnection { let sslHandler = try NIOSSLClientHandler( context: sslContext, serverHostname: configuration.sslServerHostname) - try channel.pipeline.syncOperations.addHandler(sslHandler, position: .before(decoder)) + try channel.pipeline.syncOperations.addHandler(sslHandler, position: .first) } } return channel.pipeline.addHandlers([ - decoder, MessageToByteHandler(PSQLFrontendMessageEncoder(jsonEncoder: configuration.coders.jsonEncoder)), PSQLChannelHandler( authentification: configuration.authentication, diff --git a/Tests/PostgresNIOTests/New/Extensions/ReverseChannelDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/ReverseByteToMessageHandler.swift similarity index 100% rename from Tests/PostgresNIOTests/New/Extensions/ReverseChannelDecoder.swift rename to Tests/PostgresNIOTests/New/Extensions/ReverseByteToMessageHandler.swift diff --git a/Tests/PostgresNIOTests/New/Extensions/ReverseMessageToByteHandler.swift b/Tests/PostgresNIOTests/New/Extensions/ReverseMessageToByteHandler.swift new file mode 100644 index 00000000..135c881d --- /dev/null +++ b/Tests/PostgresNIOTests/New/Extensions/ReverseMessageToByteHandler.swift @@ -0,0 +1,32 @@ +import NIOCore + +/// This is a reverse ``NIOCore/ByteToMessageHandler``. Instead of creating messages from incoming bytes +/// as the normal `ByteToMessageHandler` does, this `ReverseByteToMessageHandler` creates messages +/// from outgoing bytes. This is only important for testing in `EmbeddedChannel`s. +class ReverseMessageToByteHandler: ChannelInboundHandler { + typealias InboundIn = Encoder.OutboundIn + typealias InboundOut = ByteBuffer + + var byteBuffer: ByteBuffer! + let encoder: Encoder + + init(_ encoder: Encoder) { + self.encoder = encoder + } + + func handlerAdded(context: ChannelHandlerContext) { + self.byteBuffer = context.channel.allocator.buffer(capacity: 128) + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let message = self.unwrapInboundIn(data) + + do { + self.byteBuffer.clear() + try self.encoder.encode(data: message, out: &self.byteBuffer) + context.fireChannelRead(self.wrapInboundOut(self.byteBuffer)) + } catch { + context.fireErrorCaught(error) + } + } +} diff --git a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift index b0456d49..878e51c7 100644 --- a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift @@ -11,8 +11,10 @@ class PSQLChannelHandlerTests: XCTestCase { func testHandlerAddedWithoutSSL() { let config = self.testConnectionConfiguration() - let handler = PSQLChannelHandler(authentification: config.authentication, configureSSLCallback: nil) - let embedded = EmbeddedChannel(handler: handler) + let embedded = EmbeddedChannel(handlers: [ + ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), + PSQLChannelHandler(authentification: config.authentication, configureSSLCallback: nil) + ]) defer { XCTAssertNoThrow(try embedded.finish()) } var maybeMessage: PSQLFrontendMessage? @@ -39,7 +41,10 @@ class PSQLChannelHandlerTests: XCTestCase { let handler = PSQLChannelHandler(authentification: config.authentication) { channel in addSSLCallbackIsHit = true } - let embedded = EmbeddedChannel(handler: handler) + let embedded = EmbeddedChannel(handlers: [ + ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), + handler + ]) var maybeMessage: PSQLFrontendMessage? XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil)) @@ -78,7 +83,10 @@ class PSQLChannelHandlerTests: XCTestCase { XCTFail("This callback should never be exectuded") throw PSQLError.sslUnsupported } - let embedded = EmbeddedChannel(handler: handler) + let embedded = EmbeddedChannel(handlers: [ + ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), + handler + ]) let eventHandler = TestEventHandler() XCTAssertNoThrow(try embedded.pipeline.addHandler(eventHandler, position: .last).wait()) @@ -107,7 +115,10 @@ class PSQLChannelHandlerTests: XCTestCase { ) let state = ConnectionStateMachine(.waitingToStartAuthentication) let handler = PSQLChannelHandler(authentification: config.authentication, state: state, configureSSLCallback: nil) - let embedded = EmbeddedChannel(handler: handler) + let embedded = EmbeddedChannel(handlers: [ + ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), + handler + ]) embedded.triggerUserOutboundEvent(PSQLOutgoingEvent.authenticate(authContext), promise: nil) XCTAssertEqual(try embedded.readOutbound(as: PSQLFrontendMessage.self), .startup(.versionThree(parameters: authContext.toStartupParameters()))) @@ -132,7 +143,10 @@ class PSQLChannelHandlerTests: XCTestCase { ) let state = ConnectionStateMachine(.waitingToStartAuthentication) let handler = PSQLChannelHandler(authentification: config.authentication, state: state, configureSSLCallback: nil) - let embedded = EmbeddedChannel(handler: handler) + let embedded = EmbeddedChannel(handlers: [ + ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), + handler + ]) embedded.triggerUserOutboundEvent(PSQLOutgoingEvent.authenticate(authContext), promise: nil) XCTAssertEqual(try embedded.readOutbound(as: PSQLFrontendMessage.self), .startup(.versionThree(parameters: authContext.toStartupParameters()))) From 131deb3f1d3f20362c60fa5363339c3f5649cb8f Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 23 Sep 2021 11:36:09 +0200 Subject: [PATCH 025/292] Move message encoding into PSQLChannelHandler (#181) --- .../PostgresNIO/New/PSQLChannelHandler.swift | 109 +++++++++++------- Sources/PostgresNIO/New/PSQLConnection.swift | 3 +- .../PSQLFrontendMessageDecoder.swift | 6 +- .../New/PSQLChannelHandlerTests.swift | 16 ++- 4 files changed, 80 insertions(+), 54 deletions(-) diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PSQLChannelHandler.swift index 4a7e7808..e4d38687 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PSQLChannelHandler.swift @@ -10,7 +10,7 @@ protocol PSQLChannelHandlerNotificationDelegate: AnyObject { final class PSQLChannelHandler: ChannelDuplexHandler { typealias OutboundIn = PSQLTask typealias InboundIn = ByteBuffer - typealias OutboundOut = PSQLFrontendMessage + typealias OutboundOut = ByteBuffer private let logger: Logger private var state: ConnectionStateMachine { @@ -25,18 +25,19 @@ final class PSQLChannelHandler: ChannelDuplexHandler { private var handlerContext: ChannelHandlerContext! private var rowStream: PSQLRowStream? private var decoder: NIOSingleStepByteToMessageProcessor - private let authentificationConfiguration: PSQLConnection.Configuration.Authentication? + private var encoder: BufferedMessageEncoder! + private let configuration: PSQLConnection.Configuration private let configureSSLCallback: ((Channel) throws -> Void)? /// this delegate should only be accessed on the connections `EventLoop` weak var notificationDelegate: PSQLChannelHandlerNotificationDelegate? - init(authentification: PSQLConnection.Configuration.Authentication?, + init(configuration: PSQLConnection.Configuration, logger: Logger, configureSSLCallback: ((Channel) throws -> Void)?) { self.state = ConnectionStateMachine() - self.authentificationConfiguration = authentification + self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger self.decoder = NIOSingleStepByteToMessageProcessor(PSQLBackendMessageDecoder()) @@ -44,13 +45,13 @@ final class PSQLChannelHandler: ChannelDuplexHandler { #if DEBUG /// for testing purposes only - init(authentification: PSQLConnection.Configuration.Authentication?, + init(configuration: PSQLConnection.Configuration, state: ConnectionStateMachine = .init(.initialized), logger: Logger = .psqlNoOpLogger, configureSSLCallback: ((Channel) throws -> Void)?) { self.state = state - self.authentificationConfiguration = authentification + self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger self.decoder = NIOSingleStepByteToMessageProcessor(PSQLBackendMessageDecoder()) @@ -61,6 +62,11 @@ final class PSQLChannelHandler: ChannelDuplexHandler { func handlerAdded(context: ChannelHandlerContext) { self.handlerContext = context + self.encoder = BufferedMessageEncoder( + buffer: context.channel.allocator.buffer(capacity: 256), + encoder: PSQLFrontendMessageEncoder(jsonEncoder: self.configuration.coders.jsonEncoder) + ) + if context.channel.isActive { self.connected(context: context) } @@ -222,15 +228,19 @@ final class PSQLChannelHandler: ChannelDuplexHandler { case .wait: break case .sendStartupMessage(let authContext): - context.writeAndFlush(.startup(.versionThree(parameters: authContext.toStartupParameters())), promise: nil) + try! self.encoder.encode(.startup(.versionThree(parameters: authContext.toStartupParameters()))) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) case .sendSSLRequest: - context.writeAndFlush(.sslRequest(.init()), promise: nil) + try! self.encoder.encode(.sslRequest(.init())) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) case .sendPasswordMessage(let mode, let authContext): self.sendPasswordMessage(mode: mode, authContext: authContext, context: context) case .sendSaslInitialResponse(let name, let initialResponse): - context.writeAndFlush(.saslInitialResponse(.init(saslMechanism: name, initialData: initialResponse))) + try! self.encoder.encode(.saslInitialResponse(.init(saslMechanism: name, initialData: initialResponse))) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) case .sendSaslResponse(let bytes): - context.writeAndFlush(.saslResponse(.init(data: bytes))) + try! self.encoder.encode(.saslResponse(.init(data: bytes))) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) case .closeConnectionAndCleanup(let cleanupContext): self.closeConnectionAndCleanup(cleanupContext, context: context) case .fireChannelInactive: @@ -277,7 +287,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { case .provideAuthenticationContext: context.fireUserInboundEventTriggered(PSQLEvent.readyForStartup) - if let authentication = self.authentificationConfiguration { + if let authentication = self.configuration.authentication { let authContext = AuthContext( username: authentication.username, password: authentication.password, @@ -293,7 +303,8 @@ final class PSQLChannelHandler: ChannelDuplexHandler { // The normal, graceful termination procedure is that the frontend sends a Terminate // message and immediately closes the connection. On receipt of this message, the // backend closes the connection and terminates. - context.write(.terminate, promise: nil) + try! self.encoder.encode(.terminate) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) } context.close(mode: .all, promise: promise) case .succeedPreparedStatementCreation(let preparedContext, with: let rowDescription): @@ -357,22 +368,26 @@ final class PSQLChannelHandler: ChannelDuplexHandler { hash2.append(salt.3) let hash = "md5" + Insecure.MD5.hash(data: hash2).hexdigest() - context.writeAndFlush(.password(.init(value: hash)), promise: nil) + try! self.encoder.encode(.password(.init(value: hash))) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) + case .cleartext: - context.writeAndFlush(.password(.init(value: authContext.password ?? "")), promise: nil) + try! self.encoder.encode(.password(.init(value: authContext.password ?? ""))) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) } } private func sendCloseAndSyncMessage(_ sendClose: CloseTarget, context: ChannelHandlerContext) { switch sendClose { case .preparedStatement(let name): - context.write(.close(.preparedStatement(name)), promise: nil) - context.write(.sync, promise: nil) - context.flush() + try! self.encoder.encode(.close(.preparedStatement(name))) + try! self.encoder.encode(.sync) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) + case .portal(let name): - context.write(.close(.portal(name)), promise: nil) - context.write(.sync, promise: nil) - context.flush() + try! self.encoder.encode(.close(.portal(name))) + try! self.encoder.encode(.sync) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) } } @@ -387,10 +402,16 @@ final class PSQLChannelHandler: ChannelDuplexHandler { query: query, parameters: []) - context.write(.parse(parse), promise: nil) - context.write(.describe(.preparedStatement(statementName)), promise: nil) - context.write(.sync, promise: nil) - context.flush() + + do { + try self.encoder.encode(.parse(parse)) + try self.encoder.encode(.describe(.preparedStatement(statementName))) + try self.encoder.encode(.sync) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) + } catch { + let action = self.state.errorHappened(.channel(underlying: error)) + self.run(action, with: context) + } } private func sendBindExecuteAndSyncMessage( @@ -403,10 +424,15 @@ final class PSQLChannelHandler: ChannelDuplexHandler { preparedStatementName: statementName, parameters: binds) - context.write(.bind(bind), promise: nil) - context.write(.execute(.init(portalName: "")), promise: nil) - context.write(.sync, promise: nil) - context.flush() + do { + try self.encoder.encode(.bind(bind)) + try self.encoder.encode(.execute(.init(portalName: ""))) + try self.encoder.encode(.sync) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) + } catch { + let action = self.state.errorHappened(.channel(underlying: error)) + self.run(action, with: context) + } } private func sendParseDescribeBindExecuteAndSyncMessage( @@ -424,12 +450,17 @@ final class PSQLChannelHandler: ChannelDuplexHandler { preparedStatementName: unnamedStatementName, parameters: binds) - context.write(wrapOutboundOut(.parse(parse)), promise: nil) - context.write(wrapOutboundOut(.describe(.preparedStatement(""))), promise: nil) - context.write(wrapOutboundOut(.bind(bind)), promise: nil) - context.write(wrapOutboundOut(.execute(.init(portalName: ""))), promise: nil) - context.write(wrapOutboundOut(.sync), promise: nil) - context.flush() + do { + try self.encoder.encode(.parse(parse)) + try self.encoder.encode(.describe(.preparedStatement(""))) + try self.encoder.encode(.bind(bind)) + try self.encoder.encode(.execute(.init(portalName: ""))) + try self.encoder.encode(.sync) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) + } catch { + let action = self.state.errorHappened(.channel(underlying: error)) + self.run(action, with: context) + } } private func succeedQueryWithRowStream( @@ -503,16 +534,6 @@ extension PSQLChannelHandler: PSQLRowsDataSource { } } -extension ChannelHandlerContext { - func write(_ psqlMessage: PSQLFrontendMessage, promise: EventLoopPromise? = nil) { - self.write(NIOAny(psqlMessage), promise: promise) - } - - func writeAndFlush(_ psqlMessage: PSQLFrontendMessage, promise: EventLoopPromise? = nil) { - self.writeAndFlush(NIOAny(psqlMessage), promise: promise) - } -} - extension PSQLConnection.Configuration.Authentication { func toAuthContext() -> AuthContext { AuthContext( diff --git a/Sources/PostgresNIO/New/PSQLConnection.swift b/Sources/PostgresNIO/New/PSQLConnection.swift index ad5620aa..d6c31542 100644 --- a/Sources/PostgresNIO/New/PSQLConnection.swift +++ b/Sources/PostgresNIO/New/PSQLConnection.swift @@ -228,9 +228,8 @@ final class PSQLConnection { } return channel.pipeline.addHandlers([ - MessageToByteHandler(PSQLFrontendMessageEncoder(jsonEncoder: configuration.coders.jsonEncoder)), PSQLChannelHandler( - authentification: configuration.authentication, + configuration: configuration, logger: logger, configureSSLCallback: configureSSLCallback), PSQLEventsHandler(logger: logger) diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index 79e56507..c639f4b2 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -20,13 +20,13 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { return nil } - guard var messageSlice = buffer.getSlice(at: buffer.readerIndex &+ 4, length: Int(length)) else { + guard var messageSlice = buffer.getSlice(at: buffer.readerIndex + 4, length: Int(length) - 4) else { return nil } - buffer.moveReaderIndex(forwardBy: 4 &+ Int(length)) + buffer.moveReaderIndex(to: Int(length)) let finalIndex = buffer.readerIndex - guard let code = buffer.readInteger(as: UInt32.self) else { + guard let code = messageSlice.readInteger(as: UInt32.self) else { throw PSQLPartialDecodingError.fieldNotDecodable(type: UInt32.self) } diff --git a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift index 878e51c7..a9bfb228 100644 --- a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift @@ -11,9 +11,11 @@ class PSQLChannelHandlerTests: XCTestCase { func testHandlerAddedWithoutSSL() { let config = self.testConnectionConfiguration() + let handler = PSQLChannelHandler(configuration: config, configureSSLCallback: nil) let embedded = EmbeddedChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), - PSQLChannelHandler(authentification: config.authentication, configureSSLCallback: nil) + handler ]) defer { XCTAssertNoThrow(try embedded.finish()) } @@ -38,10 +40,11 @@ class PSQLChannelHandlerTests: XCTestCase { var config = self.testConnectionConfiguration() config.tlsConfiguration = .makeClientConfiguration() var addSSLCallbackIsHit = false - let handler = PSQLChannelHandler(authentification: config.authentication) { channel in + let handler = PSQLChannelHandler(configuration: config) { channel in addSSLCallbackIsHit = true } let embedded = EmbeddedChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), handler ]) @@ -79,11 +82,12 @@ class PSQLChannelHandlerTests: XCTestCase { var config = self.testConnectionConfiguration() config.tlsConfiguration = .makeClientConfiguration() - let handler = PSQLChannelHandler(authentification: config.authentication) { channel in + let handler = PSQLChannelHandler(configuration: config) { channel in XCTFail("This callback should never be exectuded") throw PSQLError.sslUnsupported } let embedded = EmbeddedChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), handler ]) @@ -114,8 +118,9 @@ class PSQLChannelHandlerTests: XCTestCase { database: config.authentication?.database ) let state = ConnectionStateMachine(.waitingToStartAuthentication) - let handler = PSQLChannelHandler(authentification: config.authentication, state: state, configureSSLCallback: nil) + let handler = PSQLChannelHandler(configuration: config, state: state, configureSSLCallback: nil) let embedded = EmbeddedChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), handler ]) @@ -142,8 +147,9 @@ class PSQLChannelHandlerTests: XCTestCase { database: config.authentication?.database ) let state = ConnectionStateMachine(.waitingToStartAuthentication) - let handler = PSQLChannelHandler(authentification: config.authentication, state: state, configureSSLCallback: nil) + let handler = PSQLChannelHandler(configuration: config, state: state, configureSSLCallback: nil) let embedded = EmbeddedChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), handler ]) From 046d3ba1a40c5d2e2457177f4311b0b9c97f8945 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 23 Sep 2021 20:08:23 +0200 Subject: [PATCH 026/292] Bump SwiftNIO dependency (#184) --- Package.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Package.swift b/Package.swift index 97966ead..64c261b3 100644 --- a/Package.swift +++ b/Package.swift @@ -13,7 +13,7 @@ let package = Package( .library(name: "PostgresNIO", targets: ["PostgresNIO"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-nio.git", from: "2.32.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.33.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.1"), .package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "3.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.0.0"), From c5eda6cebfdb81959f96d4ad8fcb1e8fc4596a52 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 21 Oct 2021 09:12:49 +0200 Subject: [PATCH 027/292] Make password hashing fast (#189) --- .../PostgresNIO/New/PSQLChannelHandler.swift | 45 +++++++++++++++++-- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PSQLChannelHandler.swift index e4d38687..20f3c065 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PSQLChannelHandler.swift @@ -357,16 +357,16 @@ final class PSQLChannelHandler: ChannelDuplexHandler { switch mode { case .md5(let salt): let hash1 = (authContext.password ?? "") + authContext.username - let pwdhash = Insecure.MD5.hash(data: [UInt8](hash1.utf8)).hexdigest() - + let pwdhash = Insecure.MD5.hash(data: [UInt8](hash1.utf8)).asciiHexDigest() + var hash2 = [UInt8]() hash2.reserveCapacity(pwdhash.count + 4) - hash2.append(contentsOf: pwdhash.utf8) + hash2.append(contentsOf: pwdhash) hash2.append(salt.0) hash2.append(salt.1) hash2.append(salt.2) hash2.append(salt.3) - let hash = "md5" + Insecure.MD5.hash(data: hash2).hexdigest() + let hash = Insecure.MD5.hash(data: hash2).md5PrefixHexdigest() try! self.encoder.encode(.password(.init(value: hash))) context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) @@ -552,3 +552,40 @@ extension AuthContext { replication: .false) } } + +private extension Insecure.MD5.Digest { + + private static let lowercaseLookup: [UInt8] = [ + UInt8(ascii: "0"), UInt8(ascii: "1"), UInt8(ascii: "2"), UInt8(ascii: "3"), + UInt8(ascii: "4"), UInt8(ascii: "5"), UInt8(ascii: "6"), UInt8(ascii: "7"), + UInt8(ascii: "8"), UInt8(ascii: "9"), UInt8(ascii: "a"), UInt8(ascii: "b"), + UInt8(ascii: "c"), UInt8(ascii: "d"), UInt8(ascii: "e"), UInt8(ascii: "f"), + ] + + func asciiHexDigest() -> [UInt8] { + var result = [UInt8]() + result.reserveCapacity(2 * Insecure.MD5Digest.byteCount) + for byte in self { + result.append(Self.lowercaseLookup[Int(byte >> 4)]) + result.append(Self.lowercaseLookup[Int(byte & 0x0F)]) + } + return result + } + + func md5PrefixHexdigest() -> String { + // TODO: The array should be stack allocated in the best case. But we support down to 5.2. + // Given that this method is called only on startup of a new connection, this is an + // okay tradeoff for now. + var result = [UInt8]() + result.reserveCapacity(3 + 2 * Insecure.MD5Digest.byteCount) + result.append(UInt8(ascii: "m")) + result.append(UInt8(ascii: "d")) + result.append(UInt8(ascii: "5")) + + for byte in self { + result.append(Self.lowercaseLookup[Int(byte >> 4)]) + result.append(Self.lowercaseLookup[Int(byte & 0x0F)]) + } + return String(decoding: result, as: Unicode.UTF8.self) + } +} From 24d84237fa9f2cb86b385ef97f74a4bf84f35627 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Thu, 11 Nov 2021 13:58:11 -0600 Subject: [PATCH 028/292] Rewire CI a bit (#191) * Rewire CI a bit Use a better test matrix, add code coverage collection from the unit tests, disable TSan for Concurrency reasons, use rpath workaround for macOS tests, match things up in general with postgres-kit's CI, keep unit tests separate from integration so they don't repeatedly run pointlessly for different DB versions etc. --- .github/workflows/test.yml | 199 +++++++++++++++++-------------------- 1 file changed, 89 insertions(+), 110 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 28bd3784..16c59274 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -4,91 +4,42 @@ env: LOG_LEVEL: notice jobs: - - # Test that packages depending on us still work - dependents: - strategy: - fail-fast: false - matrix: - swiftver: - - 5.2 - - 5.3 - - 5.4 - dbimage: - - postgres:13 - - postgres:12 - - postgres:11 - dependent: - - postgres-kit - - fluent-postgres-driver - container: swift:${{ matrix.swiftver }}-focal - runs-on: ubuntu-latest - services: - psql-a: - image: ${{ matrix.dbimage }} - env: - POSTGRES_USER: vapor_username - POSTGRES_DB: vapor_database - POSTGRES_PASSWORD: vapor_password - psql-b: - image: ${{ matrix.dbimage }} - env: - POSTGRES_USER: vapor_username - POSTGRES_DB: vapor_database - POSTGRES_PASSWORD: vapor_password - steps: - - name: Check out package - uses: actions/checkout@v2 - with: - path: package - - name: Check out dependent - uses: actions/checkout@v2 - with: - repository: vapor/${{ matrix.dependent }} - path: dependent - - name: Use local package - run: swift package edit postgres-nio --path ../package - working-directory: dependent - - name: Run tests with Thread Sanitizer - run: swift test --enable-test-discovery --sanitize=thread - working-directory: dependent - env: - POSTGRES_HOSTNAME: psql-a - POSTGRES_HOSTNAME_A: psql-a - POSTGRES_HOSTNAME_B: psql-b - - # Run unit tests on Linux Swift runners on - linux-unit-tests: + linux-unit: strategy: fail-fast: false matrix: swiftver: - swift:5.2 - - swift:5.3 - - swift:5.4 - - swiftlang/swift:nightly-5.5 + - swift:5.5 - swiftlang/swift:nightly-main swiftos: - #- xenial - #- bionic - focal - #- centos7 - #- centos8 - #- amazonlinux2 container: ${{ format('{0}-{1}', matrix.swiftver, matrix.swiftos) }} runs-on: ubuntu-latest + env: + LOG_LEVEL: debug + MATRIX_CONFIG: ${{ toJSON(matrix) }} steps: - - name: Check out code + - name: Check out package uses: actions/checkout@v2 - - name: Run tests with Thread Sanitizer - run: swift test --enable-test-discovery --sanitize=thread --filter=^PostgresNIOTests + - name: Run unit tests with code coverage + run: | + swift test --enable-test-discovery --filter=^PostgresNIOTests --enable-code-coverage && \ + echo "CODECOV_FILE=$(swift test --show-codecov-path)" >> $GITHUB_ENV + - name: Send coverage report to codecov.io + uses: codecov/codecov-action@v2 + with: + files: ${{ env.CODECOV_FILE }} + flags: 'unittests' + env_vars: 'MATRIX_CONFIG' + fail_ci_if_error: true - # Run integration tests on Linux Swift runners against supported PSQL versions - linux-integration-tests: + linux-integration-and-dependencies: strategy: fail-fast: false matrix: dbimage: + - postgres:14 - postgres:13 - postgres:12 - postgres:11 @@ -97,54 +48,86 @@ jobs: - md5 - scram-sha-256 swiftver: - - swift:5.4 + - swift:5.2 + - swift:5.5 + - swiftlang/swift:nightly-main swiftos: - #- xenial - #- bionic - focal - #- centos7 - #- centos8 - #- amazonlinux2 container: ${{ format('{0}-{1}', matrix.swiftver, matrix.swiftos) }} runs-on: ubuntu-latest + env: + LOG_LEVEL: debug + POSTGRES_HOSTNAME: 'psql-a' + POSTGRES_DB: 'vapor_database' + POSTGRES_USER: 'vapor_username' + POSTGRES_PASSWORD: 'vapor_password' + POSTGRES_HOSTNAME_A: 'psql-a' + POSTGRES_HOSTNAME_B: 'psql-b' + POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} services: - psql: + psql-a: image: ${{ matrix.dbimage }} env: - POSTGRES_USER: vapor_username - POSTGRES_DB: vapor_database - POSTGRES_PASSWORD: vapor_password + POSTGRES_USER: 'vapor_username' + POSTGRES_DB: 'vapor_database' + POSTGRES_PASSWORD: 'vapor_password' POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} POSTGRES_INITDB_ARGS: --auth-host=${{ matrix.dbauth }} - steps: - - name: Check out code - uses: actions/checkout@v2 - - name: Run tests with Thread Sanitizer - run: swift test --enable-test-discovery --sanitize=thread --filter=^IntegrationTests + psql-b: + image: ${{ matrix.dbimage }} env: - POSTGRES_HOSTNAME: psql - POSTGRES_USER: vapor_username - POSTGRES_DB: vapor_database - POSTGRES_PASSWORD: vapor_password + POSTGRES_USER: 'vapor_username' + POSTGRES_DB: 'vapor_database' + POSTGRES_PASSWORD: 'vapor_password' POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} + POSTGRES_INITDB_ARGS: --auth-host=${{ matrix.dbauth }} + steps: + - name: Check out package + uses: actions/checkout@v2 + with: { path: 'postgres-nio' } + - name: Run integration tests + run: swift test --package-path postgres-nio --enable-test-discovery --filter=^IntegrationTests + - name: Check out postgres-kit dependent + uses: actions/checkout@v2 + with: { repository: 'vapor/postgres-kit', path: 'postgres-kit' } + - name: Check out fluent-postgres-driver dependent + uses: actions/checkout@v2 + with: { repository: 'vapor/fluent-postgres-driver', path: 'fluent-postgres-driver' } + - name: Use local package in dependents + run: | + swift package --package-path postgres-kit edit postgres-nio --path postgres-nio + swift package --package-path fluent-postgres-driver edit postgres-nio --path postgres-nio + - name: Run postgres-kit tests + run: swift test --package-path postgres-kit --enable-test-discovery + - name: Run fluent-postgres-driver tests + run: swift test --package-path fluent-postgres-driver --enable-test-discovery - # Run package tests on macOS against supported PSQL versions - macos: + macos-all: strategy: fail-fast: false matrix: + dbimage: + # Only test the lastest couple of versions on macOS, let Linux do the rest + - postgresql@14 + - postgresql@13 + # - postgresql@12 + # - postgresql@11 + dbauth: + # Only test one auth method on macOS, Linux tests will cover the others + # - trust + # - md5 + - scram-sha-256 xcode: - latest-stable - latest - dbauth: - - trust - - md5 - - scram-sha-256 - formula: - - postgresql@11 - - postgresql@12 - - postgresql@13 - runs-on: macos-latest + runs-on: macos-11 + env: + LOG_LEVEL: debug + POSTGRES_HOSTNAME: 127.0.0.1 + POSTGRES_USER: 'vapor_username' + POSTGRES_PASSWORD: 'vapor_password' + POSTGRES_DB: 'postgres' + POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} steps: - name: Select latest available Xcode uses: maxim-lobanov/setup-xcode@v1 @@ -152,18 +135,14 @@ jobs: xcode-version: ${{ matrix.xcode }} - name: Install Postgres, setup DB and auth, and wait for server start run: | - export PATH="/usr/local/opt/${{ matrix.formula }}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test - brew install ${{ matrix.formula }} - initdb --locale=C --auth-host ${{ matrix.dbauth }} -U vapor_username --pwfile=<(echo vapor_password) + export PATH="$(brew prefix)/opt/${{ matrix.dbimage }}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test + brew install ${{ matrix.dbimage }} + initdb --locale=C --auth-host ${{ matrix.dbauth }} -U $POSTGRES_USER --pwfile=<(echo $POSTGRES_PASSWORD) pg_ctl start --wait - timeout-minutes: 5 + timeout-minutes: 2 - name: Checkout code uses: actions/checkout@v2 - - name: Run tests with Thread Sanitizer - run: swift test --enable-test-discovery --sanitize=thread - env: - POSTGRES_HOSTNAME: 127.0.0.1 - POSTGRES_USER: vapor_username - POSTGRES_DB: postgres - POSTGRES_PASSWORD: vapor_password - POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} + - name: Run all tests + run: | + swift test --enable-test-discovery -Xlinker -rpath \ + -Xlinker $(xcode-select -p)/Toolchains/XcodeDefault.xctoolchain/usr/lib/swift-5.5/macosx From 4dcec4eca4708f7705f96d4704ef783199b13795 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 13 Nov 2021 08:13:43 +0100 Subject: [PATCH 029/292] Prefix default usernames and password for test with test_ instead of vapor_ (#192) * Prefix default usernames and password for test with test_ instead of vapor_ * Update test workflow to use new defaults --- .github/workflows/test.yml | 31 ++++++++++++------- README.md | 6 ++-- .../PSQLIntegrationTests.swift | 10 +++--- Tests/IntegrationTests/Utilities.swift | 6 ++-- docker-compose.yml | 18 +++++------ 5 files changed, 40 insertions(+), 31 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 16c59274..76910591 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -57,10 +57,19 @@ jobs: runs-on: ubuntu-latest env: LOG_LEVEL: debug + # Unfortunately, fluent-postgres-driver details leak through here POSTGRES_HOSTNAME: 'psql-a' - POSTGRES_DB: 'vapor_database' - POSTGRES_USER: 'vapor_username' - POSTGRES_PASSWORD: 'vapor_password' + POSTGRES_DB: 'test_database' + POSTGRES_DATABASE: 'test_database' + POSTGRES_DATABASE_A: 'test_database' + POSTGRES_DATABASE_B: 'test_database' + POSTGRES_USER: 'test_username' + POSTGRES_USERNAME: 'test_username' + POSTGRES_USERNAME_A: 'test_username' + POSTGRES_USERNAME_B: 'test_username' + POSTGRES_PASSWORD: 'test_password' + POSTGRES_PASSWORD_A: 'test_password' + POSTGRES_PASSWORD_B: 'test_password' POSTGRES_HOSTNAME_A: 'psql-a' POSTGRES_HOSTNAME_B: 'psql-b' POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} @@ -68,17 +77,17 @@ jobs: psql-a: image: ${{ matrix.dbimage }} env: - POSTGRES_USER: 'vapor_username' - POSTGRES_DB: 'vapor_database' - POSTGRES_PASSWORD: 'vapor_password' + POSTGRES_USER: 'test_username' + POSTGRES_DB: 'test_database' + POSTGRES_PASSWORD: 'test_password' POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} POSTGRES_INITDB_ARGS: --auth-host=${{ matrix.dbauth }} psql-b: image: ${{ matrix.dbimage }} env: - POSTGRES_USER: 'vapor_username' - POSTGRES_DB: 'vapor_database' - POSTGRES_PASSWORD: 'vapor_password' + POSTGRES_USER: 'test_username' + POSTGRES_DB: 'test_database' + POSTGRES_PASSWORD: 'test_password' POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} POSTGRES_INITDB_ARGS: --auth-host=${{ matrix.dbauth }} steps: @@ -124,8 +133,8 @@ jobs: env: LOG_LEVEL: debug POSTGRES_HOSTNAME: 127.0.0.1 - POSTGRES_USER: 'vapor_username' - POSTGRES_PASSWORD: 'vapor_password' + POSTGRES_USER: 'test_username' + POSTGRES_PASSWORD: 'test_password' POSTGRES_DB: 'postgres' POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} steps: diff --git a/README.md b/README.md index 99530a85..c99ae224 100644 --- a/README.md +++ b/README.md @@ -124,9 +124,9 @@ Once you have a connection, you will need to authenticate with the server using ```swift try conn.authenticate( - username: "vapor_username", - database: "vapor_database", - password: "vapor_password" + username: "your_username", + database: "your_database", + password: "your_password" ).wait() ``` diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 011d8c70..dabe9f1c 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -25,8 +25,8 @@ final class IntegrationTests: XCTestCase { let config = PSQLConnection.Configuration( host: env("POSTGRES_HOSTNAME") ?? "localhost", port: 5432, - username: env("POSTGRES_USER") ?? "postgres", - database: env("POSTGRES_DB"), + username: env("POSTGRES_USER") ?? "test_username", + database: env("POSTGRES_DB") ?? "test_database", password: "wrong_password", tlsConfiguration: nil) @@ -327,9 +327,9 @@ extension PSQLConnection { let config = PSQLConnection.Configuration( host: env("POSTGRES_HOSTNAME") ?? "localhost", port: 5432, - username: env("POSTGRES_USER") ?? "postgres", - database: env("POSTGRES_DB"), - password: env("POSTGRES_PASSWORD"), + username: env("POSTGRES_USER") ?? "test_username", + database: env("POSTGRES_DB") ?? "test_database", + password: env("POSTGRES_PASSWORD") ?? "test_password", tlsConfiguration: nil) return PSQLConnection.connect(configuration: config, logger: logger, on: eventLoop) diff --git a/Tests/IntegrationTests/Utilities.swift b/Tests/IntegrationTests/Utilities.swift index 0964f947..070122d1 100644 --- a/Tests/IntegrationTests/Utilities.swift +++ b/Tests/IntegrationTests/Utilities.swift @@ -26,9 +26,9 @@ extension PostgresConnection { static func test(on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { return testUnauthenticated(on: eventLoop, logLevel: logLevel).flatMap { conn in return conn.authenticate( - username: env("POSTGRES_USER") ?? "vapor_username", - database: env("POSTGRES_DB") ?? "vapor_database", - password: env("POSTGRES_PASSWORD") ?? "vapor_password" + username: env("POSTGRES_USER") ?? "test_username", + database: env("POSTGRES_DB") ?? "test_database", + password: env("POSTGRES_PASSWORD") ?? "test_password" ).map { return conn }.flatMapError { error in diff --git a/docker-compose.yml b/docker-compose.yml index b3dc6e61..06b46dc9 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,9 +6,9 @@ services: user: postgres:postgres environment: POSTGRES_HOST_AUTH_METHOD: "${POSTGRES_HOST_AUTH_METHOD:-md5}" - POSTGRES_USER: vapor_username - POSTGRES_DB: vapor_database - POSTGRES_PASSWORD: vapor_password + POSTGRES_USER: test_username + POSTGRES_DB: test_database + POSTGRES_PASSWORD: test_password ports: - 5432:5432 psql-12: @@ -16,9 +16,9 @@ services: user: postgres:postgres environment: POSTGRES_HOST_AUTH_METHOD: "${POSTGRES_HOST_AUTH_METHOD:-md5}" - POSTGRES_USER: vapor_username - POSTGRES_DB: vapor_database - POSTGRES_PASSWORD: vapor_password + POSTGRES_USER: test_username + POSTGRES_DB: test_database + POSTGRES_PASSWORD: test_password ports: - 5432:5432 psql-11: @@ -26,8 +26,8 @@ services: user: postgres:postgres environment: POSTGRES_HOST_AUTH_METHOD: "${POSTGRES_HOST_AUTH_METHOD:-md5}" - POSTGRES_USER: vapor_username - POSTGRES_DB: vapor_database - POSTGRES_PASSWORD: vapor_password + POSTGRES_USER: test_username + POSTGRES_DB: test_database + POSTGRES_PASSWORD: test_password ports: - 5432:5432 From f876692dec69fc6b5cf46a3c06c4a3ed737899a7 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Sat, 13 Nov 2021 02:36:16 -0600 Subject: [PATCH 030/292] update docker-compose.yml just in case anyone's using it (#193) --- docker-compose.yml | 39 +++++++++++++++------------------------ 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 06b46dc9..600bdc99 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,33 +1,24 @@ version: '3.7' +x-shared-config: &shared_config + environment: + POSTGRES_HOST_AUTH_METHOD: "${POSTGRES_HOST_AUTH_METHOD:-scram-sha-256}" + POSTGRES_USER: test_username + POSTGRES_DB: test_database + POSTGRES_PASSWORD: test_password + ports: + - 5432:5432 + services: + psql-14: + image: postgres:14 + <<: *shared_config psql-13: image: postgres:13 - user: postgres:postgres - environment: - POSTGRES_HOST_AUTH_METHOD: "${POSTGRES_HOST_AUTH_METHOD:-md5}" - POSTGRES_USER: test_username - POSTGRES_DB: test_database - POSTGRES_PASSWORD: test_password - ports: - - 5432:5432 + <<: *shared_config psql-12: image: postgres:12 - user: postgres:postgres - environment: - POSTGRES_HOST_AUTH_METHOD: "${POSTGRES_HOST_AUTH_METHOD:-md5}" - POSTGRES_USER: test_username - POSTGRES_DB: test_database - POSTGRES_PASSWORD: test_password - ports: - - 5432:5432 + <<: *shared_config psql-11: image: postgres:11 - user: postgres:postgres - environment: - POSTGRES_HOST_AUTH_METHOD: "${POSTGRES_HOST_AUTH_METHOD:-md5}" - POSTGRES_USER: test_username - POSTGRES_DB: test_database - POSTGRES_PASSWORD: test_password - ports: - - 5432:5432 + <<: *shared_config From 4041a690bae286297559452eb7178d91205b6eea Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Thu, 18 Nov 2021 06:56:19 -0600 Subject: [PATCH 031/292] Use LCOV code coverage format (#195) * Use LCOV code coverage format instead of the broken JSON Swift still generates by default --- .github/workflows/test.yml | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 76910591..bc4fd19e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -23,9 +23,18 @@ jobs: - name: Check out package uses: actions/checkout@v2 - name: Run unit tests with code coverage + run: swift test --enable-test-discovery --filter=^PostgresNIOTests --enable-code-coverage + - name: Convert code coverage report to most expressive format run: | - swift test --enable-test-discovery --filter=^PostgresNIOTests --enable-code-coverage && \ - echo "CODECOV_FILE=$(swift test --show-codecov-path)" >> $GITHUB_ENV + export pkgname="$(swift package dump-package | perl -e 'use JSON::PP; print (decode_json(join("",(<>)))->{name});')" \ + subpath="$([ "$(uname -s)" = 'Darwin' ] && echo "/Contents/MacOS/${pkgname}PackageTests" || true)" \ + exc_prefix="$(which xcrun || true)" && \ + ${exc_prefix} llvm-cov export -format lcov \ + -instr-profile="$(dirname "$(swift test --show-codecov-path)")/default.profdata" \ + --ignore-filename-regex='\.build/' \ + "$(swift build --show-bin-path)/${pkgname}PackageTests.xctest${subpath}" \ + >"${pkgname}.lcov" + echo "CODECOV_FILE=$(pwd)/${pkgname}.lcov" >> $GITHUB_ENV - name: Send coverage report to codecov.io uses: codecov/codecov-action@v2 with: From 81157411857c6b9dc34f0bd94677e355a1853775 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Thu, 18 Nov 2021 06:59:41 -0600 Subject: [PATCH 032/292] Quick fix to test workflow --- .github/workflows/test.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bc4fd19e..d6317e14 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -31,7 +31,8 @@ jobs: exc_prefix="$(which xcrun || true)" && \ ${exc_prefix} llvm-cov export -format lcov \ -instr-profile="$(dirname "$(swift test --show-codecov-path)")/default.profdata" \ - --ignore-filename-regex='\.build/' \ + --ignore-filename-regex='/\.build/' \ + --ignore-filename-regex='/Tests/' \ "$(swift build --show-bin-path)/${pkgname}PackageTests.xctest${subpath}" \ >"${pkgname}.lcov" echo "CODECOV_FILE=$(pwd)/${pkgname}.lcov" >> $GITHUB_ENV From 549b17f880a28522f431904f803e89da72e6b680 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 18 Nov 2021 16:15:19 +0100 Subject: [PATCH 033/292] Move pre 1.5.0 unit tests into the unit test target (#196) To make sure the unit test codecov report is correct, move pre 1.5.0 unit tests into the unit test target. --- Tests/IntegrationTests/PostgresNIOTests.swift | 104 ------------------ .../Data/PostgresData+JSONTests.swift | 20 ++++ .../Message/PostgresMessageDecoderTests.swift | 37 +++++++ .../Utilities/PostgresJSONCodingTests.swift | 61 ++++++++++ 4 files changed, 118 insertions(+), 104 deletions(-) create mode 100644 Tests/PostgresNIOTests/Data/PostgresData+JSONTests.swift create mode 100644 Tests/PostgresNIOTests/Message/PostgresMessageDecoderTests.swift create mode 100644 Tests/PostgresNIOTests/Utilities/PostgresJSONCodingTests.swift diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index 308ecfee..739735bb 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -659,22 +659,6 @@ final class PostgresNIOTests: XCTestCase { } } - func testJSONBConvertible() { - struct Object: PostgresJSONBCodable { - let foo: Int - let bar: Int - } - - XCTAssertEqual(Object.postgresDataType, .jsonb) - - let postgresData = Object(foo: 1, bar: 2).postgresData - XCTAssertEqual(postgresData?.type, .jsonb) - - let object = Object(postgresData: postgresData!) - XCTAssertEqual(object?.foo, 1) - XCTAssertEqual(object?.bar, 2) - } - func testRemoteTLSServer() { // postgres://uymgphwj:7_tHbREdRwkqAdu4KoIS7hQnNxr8J1LA@elmer.db.elephantsql.com:5432/uymgphwj var conn: PostgresConnection? @@ -899,38 +883,6 @@ final class PostgresNIOTests: XCTestCase { XCTAssertEqual(rows?.first?.column("int")?.uint8, 5) } - func testMessageDecoder() { - let sample: [UInt8] = [ - 0x52, // R - authentication - 0x00, 0x00, 0x00, 0x0C, // length = 12 - 0x00, 0x00, 0x00, 0x05, // md5 - 0x01, 0x02, 0x03, 0x04, // salt - 0x4B, // B - backend key data - 0x00, 0x00, 0x00, 0x0C, // length = 12 - 0x05, 0x05, 0x05, 0x05, // process id - 0x01, 0x01, 0x01, 0x01, // secret key - ] - var input = ByteBufferAllocator().buffer(capacity: 0) - input.writeBytes(sample) - - let output: [PostgresMessage] = [ - PostgresMessage(identifier: .authentication, bytes: [ - 0x00, 0x00, 0x00, 0x05, - 0x01, 0x02, 0x03, 0x04, - ]), - PostgresMessage(identifier: .backendKeyData, bytes: [ - 0x05, 0x05, 0x05, 0x05, - 0x01, 0x01, 0x01, 0x01, - ]) - ] - XCTAssertNoThrow(try XCTUnwrap(ByteToMessageDecoderVerifier.verifyDecoder( - inputOutputPairs: [(input, output)], - decoderFactory: { - PostgresMessageDecoder() - } - ))) - } - func testPreparedQuery() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) @@ -1135,62 +1087,6 @@ final class PostgresNIOTests: XCTestCase { XCTAssertEqual(rows?.first?.column("min64")?.int64, .min) XCTAssertEqual(rows?.first?.column("max64")?.int64, .max) } - - // https://github.com/vapor/postgres-nio/issues/126 - func testCustomJSONEncoder() { - let previousDefaultJSONEncoder = PostgresNIO._defaultJSONEncoder - defer { - PostgresNIO._defaultJSONEncoder = previousDefaultJSONEncoder - } - final class CustomJSONEncoder: PostgresJSONEncoder { - var didEncode = false - func encode(_ value: T) throws -> Data where T : Encodable { - self.didEncode = true - return try JSONEncoder().encode(value) - } - } - struct Object: Codable { - var foo: Int - var bar: Int - } - let customJSONEncoder = CustomJSONEncoder() - PostgresNIO._defaultJSONEncoder = customJSONEncoder - XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2))) - XCTAssert(customJSONEncoder.didEncode) - - let customJSONBEncoder = CustomJSONEncoder() - PostgresNIO._defaultJSONEncoder = customJSONBEncoder - XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2))) - XCTAssert(customJSONBEncoder.didEncode) - } - - // https://github.com/vapor/postgres-nio/issues/126 - func testCustomJSONDecoder() { - let previousDefaultJSONDecoder = PostgresNIO._defaultJSONDecoder - defer { - PostgresNIO._defaultJSONDecoder = previousDefaultJSONDecoder - } - final class CustomJSONDecoder: PostgresJSONDecoder { - var didDecode = false - func decode(_ type: T.Type, from data: Data) throws -> T where T : Decodable { - self.didDecode = true - return try JSONDecoder().decode(type, from: data) - } - } - struct Object: Codable { - var foo: Int - var bar: Int - } - let customJSONDecoder = CustomJSONDecoder() - PostgresNIO._defaultJSONDecoder = customJSONDecoder - XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2)).json(as: Object.self)) - XCTAssert(customJSONDecoder.didDecode) - - let customJSONBDecoder = CustomJSONDecoder() - PostgresNIO._defaultJSONDecoder = customJSONBDecoder - XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2)).json(as: Object.self)) - XCTAssert(customJSONBDecoder.didDecode) - } } let isLoggingConfigured: Bool = { diff --git a/Tests/PostgresNIOTests/Data/PostgresData+JSONTests.swift b/Tests/PostgresNIOTests/Data/PostgresData+JSONTests.swift new file mode 100644 index 00000000..a8287966 --- /dev/null +++ b/Tests/PostgresNIOTests/Data/PostgresData+JSONTests.swift @@ -0,0 +1,20 @@ +import PostgresNIO +import XCTest + +class PostgresData_JSONTests: XCTestCase { + func testJSONBConvertible() { + struct Object: PostgresJSONBCodable { + let foo: Int + let bar: Int + } + + XCTAssertEqual(Object.postgresDataType, .jsonb) + + let postgresData = Object(foo: 1, bar: 2).postgresData + XCTAssertEqual(postgresData?.type, .jsonb) + + let object = Object(postgresData: postgresData!) + XCTAssertEqual(object?.foo, 1) + XCTAssertEqual(object?.bar, 2) + } +} diff --git a/Tests/PostgresNIOTests/Message/PostgresMessageDecoderTests.swift b/Tests/PostgresNIOTests/Message/PostgresMessageDecoderTests.swift new file mode 100644 index 00000000..e9a970ef --- /dev/null +++ b/Tests/PostgresNIOTests/Message/PostgresMessageDecoderTests.swift @@ -0,0 +1,37 @@ +import PostgresNIO +import XCTest +import NIOTestUtils + +class PostgresMessageDecoderTests: XCTestCase { + func testMessageDecoder() { + let sample: [UInt8] = [ + 0x52, // R - authentication + 0x00, 0x00, 0x00, 0x0C, // length = 12 + 0x00, 0x00, 0x00, 0x05, // md5 + 0x01, 0x02, 0x03, 0x04, // salt + 0x4B, // B - backend key data + 0x00, 0x00, 0x00, 0x0C, // length = 12 + 0x05, 0x05, 0x05, 0x05, // process id + 0x01, 0x01, 0x01, 0x01, // secret key + ] + var input = ByteBufferAllocator().buffer(capacity: 0) + input.writeBytes(sample) + + let output: [PostgresMessage] = [ + PostgresMessage(identifier: .authentication, bytes: [ + 0x00, 0x00, 0x00, 0x05, + 0x01, 0x02, 0x03, 0x04, + ]), + PostgresMessage(identifier: .backendKeyData, bytes: [ + 0x05, 0x05, 0x05, 0x05, + 0x01, 0x01, 0x01, 0x01, + ]) + ] + XCTAssertNoThrow(try XCTUnwrap(ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(input, output)], + decoderFactory: { + PostgresMessageDecoder() + } + ))) + } +} diff --git a/Tests/PostgresNIOTests/Utilities/PostgresJSONCodingTests.swift b/Tests/PostgresNIOTests/Utilities/PostgresJSONCodingTests.swift new file mode 100644 index 00000000..2aad52b6 --- /dev/null +++ b/Tests/PostgresNIOTests/Utilities/PostgresJSONCodingTests.swift @@ -0,0 +1,61 @@ +import NIOCore +import XCTest +import PostgresNIO + +class PostgresJSONCodingTests: XCTestCase { + // https://github.com/vapor/postgres-nio/issues/126 + func testCustomJSONEncoder() { + let previousDefaultJSONEncoder = PostgresNIO._defaultJSONEncoder + defer { + PostgresNIO._defaultJSONEncoder = previousDefaultJSONEncoder + } + final class CustomJSONEncoder: PostgresJSONEncoder { + var didEncode = false + func encode(_ value: T) throws -> Data where T : Encodable { + self.didEncode = true + return try JSONEncoder().encode(value) + } + } + struct Object: Codable { + var foo: Int + var bar: Int + } + let customJSONEncoder = CustomJSONEncoder() + PostgresNIO._defaultJSONEncoder = customJSONEncoder + XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2))) + XCTAssert(customJSONEncoder.didEncode) + + let customJSONBEncoder = CustomJSONEncoder() + PostgresNIO._defaultJSONEncoder = customJSONBEncoder + XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2))) + XCTAssert(customJSONBEncoder.didEncode) + } + + // https://github.com/vapor/postgres-nio/issues/126 + func testCustomJSONDecoder() { + let previousDefaultJSONDecoder = PostgresNIO._defaultJSONDecoder + defer { + PostgresNIO._defaultJSONDecoder = previousDefaultJSONDecoder + } + final class CustomJSONDecoder: PostgresJSONDecoder { + var didDecode = false + func decode(_ type: T.Type, from data: Data) throws -> T where T : Decodable { + self.didDecode = true + return try JSONDecoder().decode(type, from: data) + } + } + struct Object: Codable { + var foo: Int + var bar: Int + } + let customJSONDecoder = CustomJSONDecoder() + PostgresNIO._defaultJSONDecoder = customJSONDecoder + XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2)).json(as: Object.self)) + XCTAssert(customJSONDecoder.didDecode) + + let customJSONBDecoder = CustomJSONDecoder() + PostgresNIO._defaultJSONDecoder = customJSONBDecoder + XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2)).json(as: Object.self)) + XCTAssert(customJSONBDecoder.didDecode) + } +} From c0376e0caca73e1e59438bb358b3153500d2b8d1 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Fri, 19 Nov 2021 17:04:44 -0600 Subject: [PATCH 034/292] Rejigger test matrix again (#197) --- .github/workflows/test.yml | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d6317e14..68d7cd18 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,6 +10,8 @@ jobs: matrix: swiftver: - swift:5.2 + - swift:5.3 + - swift:5.4 - swift:5.5 - swiftlang/swift:nightly-main swiftos: @@ -18,7 +20,7 @@ jobs: runs-on: ubuntu-latest env: LOG_LEVEL: debug - MATRIX_CONFIG: ${{ toJSON(matrix) }} + MATRIX_CONFIG: ${{ format('{0}-{1}', matrix.swiftver, matrix.swiftos) }} steps: - name: Check out package uses: actions/checkout@v2 @@ -31,10 +33,8 @@ jobs: exc_prefix="$(which xcrun || true)" && \ ${exc_prefix} llvm-cov export -format lcov \ -instr-profile="$(dirname "$(swift test --show-codecov-path)")/default.profdata" \ - --ignore-filename-regex='/\.build/' \ - --ignore-filename-regex='/Tests/' \ - "$(swift build --show-bin-path)/${pkgname}PackageTests.xctest${subpath}" \ - >"${pkgname}.lcov" + --ignore-filename-regex='/\.build/' --ignore-filename-regex='/Tests/' \ + "$(swift build --show-bin-path)/${pkgname}PackageTests.xctest${subpath}" >"${pkgname}.lcov" echo "CODECOV_FILE=$(pwd)/${pkgname}.lcov" >> $GITHUB_ENV - name: Send coverage report to codecov.io uses: codecov/codecov-action@v2 @@ -60,7 +60,6 @@ jobs: swiftver: - swift:5.2 - swift:5.5 - - swiftlang/swift:nightly-main swiftos: - focal container: ${{ format('{0}-{1}', matrix.swiftver, matrix.swiftos) }} @@ -126,15 +125,10 @@ jobs: fail-fast: false matrix: dbimage: - # Only test the lastest couple of versions on macOS, let Linux do the rest + # Only test the lastest version on macOS, let Linux do the rest - postgresql@14 - - postgresql@13 - # - postgresql@12 - # - postgresql@11 dbauth: # Only test one auth method on macOS, Linux tests will cover the others - # - trust - # - md5 - scram-sha-256 xcode: - latest-stable From 99673602240e9786e3b6021da4c2b8ce9d041fa2 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Mon, 22 Nov 2021 03:43:19 -0600 Subject: [PATCH 035/292] Add workflow to update coverage on pushes to main --- .github/workflows/main-codecov.yml | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 .github/workflows/main-codecov.yml diff --git a/.github/workflows/main-codecov.yml b/.github/workflows/main-codecov.yml new file mode 100644 index 00000000..c6db91d5 --- /dev/null +++ b/.github/workflows/main-codecov.yml @@ -0,0 +1,27 @@ +name: main codecov +on: + push: + branches: + - main +jobs: + update-main-codecov: + runs-on: ubuntu-latest + container: swift:5.5-focal + steps: + - name: Check out main + uses: actions/checkout@v2 + - name: Run unit tests with code coverage and Thread Sanitizer + run: swift test --enable-code-coverage --sanitize=thread --filter=^PostgresNIOTests + - name: Convert profdata to LCOV for upload + run: | + llvm-cov export -format lcov \ + -instr-profile="$(dirname $(swift test --show-codecov-path))/default.profdata" \ + --ignore-filename-regex='/(\.build|Tests)/' \ + "$(swift build --show-bin-path)/postgres-nioPackageTests.xctest" >postgres-nio.lcov + echo "CODECOV_FILE=$(pwd)/postgres-nio.lcov" >>"${GITHUB_ENV}" + - name: Upload LCOV report to Codecov.io + uses: codecov/codecov-action@v2 + with: + files: ${{ env.CODECOV_FILE }} + flags: 'unittests' + fail_ci_if_error: true From 1042870acc1c0aeb74ca44ced238ef2a639ba5b5 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 22 Nov 2021 11:24:23 +0100 Subject: [PATCH 036/292] DataRow without allocation; DataRow as Collection; RowDescription top level (#198) This is a cherry pick of #188. ### Modifications - `DataRow` and `RowDescription` have been moved out of the `PSQLBackendMessage` namespace. This allows us to mark them as `@inlinable` or `@usableFromInline` at a later point, without marking everything in `PSQLBackendMessage` as `@inlinable` - `DataRow` does not use an internal array for its columns anymore. Instead all read operations are directly done on its ByteBuffer slice. - `DataRow` implements the `Collection` protocol now. ### Result One allocation fewer per queried row. --- .../PostgresConnection+Database.swift | 4 +- .../ConnectionStateMachine.swift | 12 +- .../ExtendedQueryStateMachine.swift | 18 +-- .../PrepareStatementStateMachine.swift | 4 +- .../RowStreamStateMachine.swift | 18 +-- .../PostgresNIO/New/Messages/DataRow.swift | 129 ++++++++++++++--- .../New/Messages/RowDescription.swift | 131 +++++++++--------- .../PostgresNIO/New/PSQLChannelHandler.swift | 2 +- Sources/PostgresNIO/New/PSQLConnection.swift | 2 +- .../New/PSQLPreparedStatement.swift | 2 +- Sources/PostgresNIO/New/PSQLRow.swift | 39 +++--- Sources/PostgresNIO/New/PSQLRowStream.swift | 12 +- Sources/PostgresNIO/New/PSQLTask.swift | 6 +- .../ExtendedQueryStateMachineTests.swift | 16 +-- .../PrepareStatementStateMachineTests.swift | 6 +- .../PSQLBackendMessage+Equatable.swift | 8 -- .../PSQLBackendMessageEncoder.swift | 17 +-- .../New/Messages/DataRowTests.swift | 124 +++++++++++++++-- .../New/Messages/RowDescriptionTests.swift | 10 +- 19 files changed, 367 insertions(+), 193 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift index 725f17d8..68e6c96c 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift @@ -50,7 +50,7 @@ extension PostgresConnection: PostgresDatabase { let lookupTable = PostgresRow.LookupTable(rowDescription: .init(fields: fields), resultFormat: [.binary]) return rows.all().map { allrows in let r = allrows.map { psqlRow -> PostgresRow in - let columns = psqlRow.data.columns.map { + let columns = psqlRow.data.map { PostgresMessage.DataRow.Column(value: $0) } return PostgresRow(dataRow: .init(columns: columns), lookupTable: lookupTable) @@ -112,7 +112,7 @@ extension PSQLRowStream { func iterateRowsWithoutBackpressureOption(lookupTable: PostgresRow.LookupTable, onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { self.onRow { psqlRow in - let columns = psqlRow.data.columns.map { + let columns = psqlRow.data.map { PostgresMessage.DataRow.Column(value: $0) } diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 1af28a3b..27dd40dc 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -87,18 +87,18 @@ struct ConnectionStateMachine { case sendParseDescribeBindExecuteSync(query: String, binds: [PSQLEncodable]) case sendBindExecuteSync(statementName: String, binds: [PSQLEncodable]) case failQuery(ExtendedQueryContext, with: PSQLError, cleanupContext: CleanUpContext?) - case succeedQuery(ExtendedQueryContext, columns: [PSQLBackendMessage.RowDescription.Column]) + case succeedQuery(ExtendedQueryContext, columns: [RowDescription.Column]) case succeedQueryNoRowsComming(ExtendedQueryContext, commandTag: String) // --- streaming actions // actions if query has requested next row but we are waiting for backend - case forwardRows(CircularBuffer) - case forwardStreamComplete(CircularBuffer, commandTag: String) + case forwardRows([DataRow]) + case forwardStreamComplete([DataRow], commandTag: String) case forwardStreamError(PSQLError, read: Bool, cleanupContext: CleanUpContext?) // Prepare statement actions case sendParseDescribeSync(name: String, query: String) - case succeedPreparedStatementCreation(PrepareStatementContext, with: PSQLBackendMessage.RowDescription?) + case succeedPreparedStatementCreation(PrepareStatementContext, with: RowDescription?) case failPreparedStatementCreation(PrepareStatementContext, with: PSQLError, cleanupContext: CleanUpContext?) // Close actions @@ -713,7 +713,7 @@ struct ConnectionStateMachine { } } - mutating func rowDescriptionReceived(_ description: PSQLBackendMessage.RowDescription) -> ConnectionAction { + mutating func rowDescriptionReceived(_ description: RowDescription) -> ConnectionAction { switch self.state { case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: return self.avoidingStateMachineCoW { machine -> ConnectionAction in @@ -791,7 +791,7 @@ struct ConnectionStateMachine { } } - mutating func dataRowReceived(_ dataRow: PSQLBackendMessage.DataRow) -> ConnectionAction { + mutating func dataRowReceived(_ dataRow: DataRow) -> ConnectionAction { guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.dataRow(dataRow))) } diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 4818ca19..67fe219f 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -8,13 +8,13 @@ struct ExtendedQueryStateMachine { case parseCompleteReceived(ExtendedQueryContext) case parameterDescriptionReceived(ExtendedQueryContext) - case rowDescriptionReceived(ExtendedQueryContext, [PSQLBackendMessage.RowDescription.Column]) + case rowDescriptionReceived(ExtendedQueryContext, [RowDescription.Column]) case noDataMessageReceived(ExtendedQueryContext) /// A state that is used if a noData message was received before. If a row description was received `bufferingRows` is /// used after receiving a `bindComplete` message case bindCompleteReceived(ExtendedQueryContext) - case streaming([PSQLBackendMessage.RowDescription.Column], RowStreamStateMachine) + case streaming([RowDescription.Column], RowStreamStateMachine) case commandComplete(commandTag: String) case error(PSQLError) @@ -28,13 +28,13 @@ struct ExtendedQueryStateMachine { // --- general actions case failQuery(ExtendedQueryContext, with: PSQLError) - case succeedQuery(ExtendedQueryContext, columns: [PSQLBackendMessage.RowDescription.Column]) + case succeedQuery(ExtendedQueryContext, columns: [RowDescription.Column]) case succeedQueryNoRowsComming(ExtendedQueryContext, commandTag: String) // --- streaming actions // actions if query has requested next row but we are waiting for backend - case forwardRows(CircularBuffer) - case forwardStreamComplete(CircularBuffer, commandTag: String) + case forwardRows([DataRow]) + case forwardStreamComplete([DataRow], commandTag: String) case forwardStreamError(PSQLError, read: Bool) case read @@ -105,7 +105,7 @@ struct ExtendedQueryStateMachine { } } - mutating func rowDescriptionReceived(_ rowDescription: PSQLBackendMessage.RowDescription) -> Action { + mutating func rowDescriptionReceived(_ rowDescription: RowDescription) -> Action { guard case .parameterDescriptionReceived(let queryContext) = self.state else { return self.setAndFireError(.unexpectedBackendMessage(.rowDescription(rowDescription))) } @@ -119,7 +119,7 @@ struct ExtendedQueryStateMachine { // In Postgres extended queries we always request the response rows to be returned in // `.binary` format. - let columns = rowDescription.columns.map { column -> PSQLBackendMessage.RowDescription.Column in + let columns = rowDescription.columns.map { column -> RowDescription.Column in var column = column column.format = .binary return column @@ -155,12 +155,12 @@ struct ExtendedQueryStateMachine { } } - mutating func dataRowReceived(_ dataRow: PSQLBackendMessage.DataRow) -> Action { + mutating func dataRowReceived(_ dataRow: DataRow) -> Action { switch self.state { case .streaming(let columns, var demandStateMachine): // When receiving a data row, we must ensure that the data row column count // matches the previously received row description column count. - guard dataRow.columns.count == columns.count else { + guard dataRow.columnCount == columns.count else { return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) } diff --git a/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift index 98e18dbc..947c8f97 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift @@ -15,7 +15,7 @@ struct PrepareStatementStateMachine { enum Action { case sendParseDescribeSync(name: String, query: String) - case succeedPreparedStatementCreation(PrepareStatementContext, with: PSQLBackendMessage.RowDescription?) + case succeedPreparedStatementCreation(PrepareStatementContext, with: RowDescription?) case failPreparedStatementCreation(PrepareStatementContext, with: PSQLError) case read @@ -72,7 +72,7 @@ struct PrepareStatementStateMachine { return .succeedPreparedStatementCreation(queryContext, with: nil) } - mutating func rowDescriptionReceived(_ rowDescription: PSQLBackendMessage.RowDescription) -> Action { + mutating func rowDescriptionReceived(_ rowDescription: RowDescription) -> Action { guard case .parameterDescriptionReceived(let queryContext) = self.state else { return self.setAndFireError(.unexpectedBackendMessage(.rowDescription(rowDescription))) } diff --git a/Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift index 165ba4f3..08953fb2 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift @@ -13,15 +13,15 @@ struct RowStreamStateMachine { private enum State { /// The state machines expects further writes to `channelRead`. The writes are appended to the buffer. - case waitingForRows(CircularBuffer) + case waitingForRows([DataRow]) /// The state machines expects a call to `demandMoreResponseBodyParts` or `read`. The buffer is /// empty. It is preserved for performance reasons. - case waitingForReadOrDemand(CircularBuffer) + case waitingForReadOrDemand([DataRow]) /// The state machines expects a call to `read`. The buffer is empty. It is preserved for performance reasons. - case waitingForRead(CircularBuffer) + case waitingForRead([DataRow]) /// The state machines expects a call to `demandMoreResponseBodyParts`. The buffer is empty. It is /// preserved for performance reasons. - case waitingForDemand(CircularBuffer) + case waitingForDemand([DataRow]) case modifying } @@ -29,10 +29,12 @@ struct RowStreamStateMachine { private var state: State init() { - self.state = .waitingForRows(CircularBuffer(initialCapacity: 32)) + var buffer = [DataRow]() + buffer.reserveCapacity(32) + self.state = .waitingForRows(buffer) } - mutating func receivedRow(_ newRow: PSQLBackendMessage.DataRow) { + mutating func receivedRow(_ newRow: DataRow) { switch self.state { case .waitingForRows(var buffer): self.state = .modifying @@ -66,7 +68,7 @@ struct RowStreamStateMachine { } } - mutating func channelReadComplete() -> CircularBuffer? { + mutating func channelReadComplete() -> [DataRow]? { switch self.state { case .waitingForRows(let buffer): if buffer.isEmpty { @@ -139,7 +141,7 @@ struct RowStreamStateMachine { } } - mutating func end() -> CircularBuffer { + mutating func end() -> [DataRow] { switch self.state { case .waitingForRows(let buffer): return buffer diff --git a/Sources/PostgresNIO/New/Messages/DataRow.swift b/Sources/PostgresNIO/New/Messages/DataRow.swift index 3047ccc2..54044c6a 100644 --- a/Sources/PostgresNIO/New/Messages/DataRow.swift +++ b/Sources/PostgresNIO/New/Messages/DataRow.swift @@ -1,34 +1,117 @@ import NIOCore -extension PSQLBackendMessage { +/// A backend data row message. +/// +/// - NOTE: This struct is not part of the ``PSQLBackendMessage`` namespace even +/// though this is where it actually belongs. The reason for this is, that we want +/// this type to be @usableFromInline. If a type is made @usableFromInline in an +/// enclosing type, the enclosing type must be @usableFromInline as well. +/// Not putting `DataRow` in ``PSQLBackendMessage`` is our way to trick +/// the Swift compiler +struct DataRow: PSQLBackendMessage.PayloadDecodable, Equatable { - struct DataRow: PayloadDecodable, Equatable { - - var columns: [ByteBuffer?] + var columnCount: Int16 + + var bytes: ByteBuffer + + static func decode(from buffer: inout ByteBuffer) throws -> Self { + try buffer.ensureAtLeastNBytesRemaining(2) + let columnCount = buffer.readInteger(as: Int16.self)! + let firstColumnIndex = buffer.readerIndex - static func decode(from buffer: inout ByteBuffer) throws -> Self { + for _ in 0..= 0 else { - result.append(nil) - continue - } - - try buffer.ensureAtLeastNBytesRemaining(bufferLength) - let columnBuffer = buffer.readSlice(length: Int(bufferLength))! - - result.append(columnBuffer) + guard bufferLength >= 0 else { + // if buffer length is negative, this means that the value is null + continue } - return DataRow(columns: result) + try buffer.ensureAtLeastNBytesRemaining(bufferLength) + buffer.moveReaderIndex(forwardBy: bufferLength) + } + + try buffer.ensureExactNBytesRemaining(0) + + buffer.moveReaderIndex(to: firstColumnIndex) + let columnSlice = buffer.readSlice(length: buffer.readableBytes)! + return DataRow(columnCount: columnCount, bytes: columnSlice) + } +} + +extension DataRow: Sequence { + typealias Element = ByteBuffer? + + // There is no contiguous storage available... Sadly + func withContiguousStorageIfAvailable(_ body: (UnsafeBufferPointer) throws -> R) rethrows -> R? { + nil + } +} + +extension DataRow: Collection { + + struct ColumnIndex: Comparable { + var offset: Int + + init(_ index: Int) { + self.offset = index + } + + // Only needed implementation for comparable. The compiler synthesizes the rest from this. + static func < (lhs: Self, rhs: Self) -> Bool { + lhs.offset < rhs.offset + } + } + + typealias Index = DataRow.ColumnIndex + + var startIndex: ColumnIndex { + ColumnIndex(self.bytes.readerIndex) + } + + var endIndex: ColumnIndex { + ColumnIndex(self.bytes.readerIndex + self.bytes.readableBytes) + } + + var count: Int { + Int(self.columnCount) + } + + func index(after index: ColumnIndex) -> ColumnIndex { + guard index < self.endIndex else { + preconditionFailure("index out of bounds") + } + var elementLength = Int(self.bytes.getInteger(at: index.offset, as: Int32.self)!) + if elementLength < 0 { + elementLength = 0 + } + return ColumnIndex(index.offset + MemoryLayout.size + elementLength) + } + + subscript(index: ColumnIndex) -> Element { + guard index < self.endIndex else { + preconditionFailure("index out of bounds") + } + let elementLength = Int(self.bytes.getInteger(at: index.offset, as: Int32.self)!) + if elementLength < 0 { + return nil + } + return self.bytes.getSlice(at: index.offset + MemoryLayout.size, length: elementLength)! + } +} + +extension DataRow { + subscript(column index: Int) -> Element { + guard index < self.columnCount else { + preconditionFailure("index out of bounds") } + + var byteIndex = self.startIndex + for _ in 0.. Self { + try buffer.ensureAtLeastNBytesRemaining(2) + let columnCount = buffer.readInteger(as: Int16.self)! + + guard columnCount >= 0 else { + throw PSQLPartialDecodingError.integerMustBePositiveOrNull(columnCount) } - static func decode(from buffer: inout ByteBuffer) throws -> Self { - try buffer.ensureAtLeastNBytesRemaining(2) - let columnCount = buffer.readInteger(as: Int16.self)! - - guard columnCount >= 0 else { - throw PSQLPartialDecodingError.integerMustBePositiveOrNull(columnCount) + var result = [Column]() + result.reserveCapacity(Int(columnCount)) + + for _ in 0.. EventLoopFuture { - let promise = self.channel.eventLoop.makePromise(of: PSQLBackendMessage.RowDescription?.self) + let promise = self.channel.eventLoop.makePromise(of: RowDescription?.self) let context = PrepareStatementContext( name: name, query: query, diff --git a/Sources/PostgresNIO/New/PSQLPreparedStatement.swift b/Sources/PostgresNIO/New/PSQLPreparedStatement.swift index c5a08be9..fbdfd868 100644 --- a/Sources/PostgresNIO/New/PSQLPreparedStatement.swift +++ b/Sources/PostgresNIO/New/PSQLPreparedStatement.swift @@ -10,5 +10,5 @@ struct PSQLPreparedStatement { let connection: PSQLConnection /// The `RowDescription` to apply to all `DataRow`s when executing this `PSQLPreparedStatement` - let rowDescription: PSQLBackendMessage.RowDescription? + let rowDescription: RowDescription? } diff --git a/Sources/PostgresNIO/New/PSQLRow.swift b/Sources/PostgresNIO/New/PSQLRow.swift index c5efb53a..e7a6ed7e 100644 --- a/Sources/PostgresNIO/New/PSQLRow.swift +++ b/Sources/PostgresNIO/New/PSQLRow.swift @@ -1,34 +1,22 @@ +import NIOCore /// `PSQLRow` represents a single row that was received from the Postgres Server. struct PSQLRow { internal let lookupTable: [String: Int] - internal let data: PSQLBackendMessage.DataRow + internal let data: DataRow - internal let columns: [PSQLBackendMessage.RowDescription.Column] + internal let columns: [RowDescription.Column] internal let jsonDecoder: PSQLJSONDecoder - internal init(data: PSQLBackendMessage.DataRow, lookupTable: [String: Int], columns: [PSQLBackendMessage.RowDescription.Column], jsonDecoder: PSQLJSONDecoder) { + internal init(data: DataRow, lookupTable: [String: Int], columns: [RowDescription.Column], jsonDecoder: PSQLJSONDecoder) { self.data = data self.lookupTable = lookupTable self.columns = columns self.jsonDecoder = jsonDecoder } - - /// Access the raw Postgres data in the n-th column - subscript(index: Int) -> PSQLData { - PSQLData(bytes: self.data.columns[index], dataType: self.columns[index].dataType, format: self.columns[index].format) - } - - // TBD: Should this be optional? - /// Access the raw Postgres data in the column indentified by name - subscript(column columnName: String) -> PSQLData? { - guard let index = self.lookupTable[columnName] else { - return nil - } - - return self[index] - } - +} + +extension PSQLRow { /// Access the data in the provided column and decode it into the target type. /// /// - Parameters: @@ -52,15 +40,20 @@ struct PSQLRow { /// - Throws: The error of the decoding implementation. See also `PSQLDecodable` protocol for this. /// - Returns: The decoded value of Type T. func decode(column index: Int, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { - let column = self.columns[index] + precondition(index < self.data.columnCount) - let decodingContext = PSQLDecodingContext( - jsonDecoder: jsonDecoder, + let column = self.columns[index] + let context = PSQLDecodingContext( + jsonDecoder: self.jsonDecoder, columnName: column.name, columnIndex: index, file: file, line: line) - return try self[index].decode(as: T.self, context: decodingContext) + guard var cellSlice = self.data[column: index] else { + throw PSQLCastingError.missingData(targetType: T.self, type: column.dataType, context: context) + } + + return try T.decode(from: &cellSlice, type: column.dataType, format: column.format, context: context) } } diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index 768255fb..e3d74f16 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -12,8 +12,8 @@ final class PSQLRowStream { let logger: Logger private enum UpstreamState { - case streaming(buffer: CircularBuffer, dataSource: PSQLRowsDataSource) - case finished(buffer: CircularBuffer, commandTag: String) + case streaming(buffer: CircularBuffer, dataSource: PSQLRowsDataSource) + case finished(buffer: CircularBuffer, commandTag: String) case failure(Error) case consumed(Result) case modifying @@ -25,18 +25,18 @@ final class PSQLRowStream { case consuming } - internal let rowDescription: [PSQLBackendMessage.RowDescription.Column] + internal let rowDescription: [RowDescription.Column] private let lookupTable: [String: Int] private var upstreamState: UpstreamState private var downstreamState: DownstreamState private let jsonDecoder: PSQLJSONDecoder - init(rowDescription: [PSQLBackendMessage.RowDescription.Column], + init(rowDescription: [RowDescription.Column], queryContext: ExtendedQueryContext, eventLoop: EventLoop, rowSource: RowSource) { - let buffer = CircularBuffer() + let buffer = CircularBuffer() self.downstreamState = .consuming switch rowSource { @@ -186,7 +186,7 @@ final class PSQLRowStream { ]) } - internal func receive(_ newRows: CircularBuffer) { + internal func receive(_ newRows: [DataRow]) { precondition(!newRows.isEmpty, "Expected to get rows!") self.eventLoop.preconditionInEventLoop() self.logger.trace("Row stream received rows", metadata: [ diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index af3e8ee4..1f7a06d6 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -21,7 +21,7 @@ enum PSQLTask { final class ExtendedQueryContext { enum Query { case unnamed(String) - case preparedStatement(name: String, rowDescription: PSQLBackendMessage.RowDescription?) + case preparedStatement(name: String, rowDescription: RowDescription?) } let query: Query @@ -65,12 +65,12 @@ final class PrepareStatementContext { let name: String let query: String let logger: Logger - let promise: EventLoopPromise + let promise: EventLoopPromise init(name: String, query: String, logger: Logger, - promise: EventLoopPromise) + promise: EventLoopPromise) { self.name = name self.query = query diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index e1076a6e..39360645 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -40,25 +40,25 @@ class ExtendedQueryStateMachineTests: XCTestCase { // We need to ensure that even though the row description from the wire says that we // will receive data in `.text` format, we will actually receive it in binary format, // since we requested it in binary with our bind message. - let input: [PSQLBackendMessage.RowDescription.Column] = [ + let input: [RowDescription.Column] = [ .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) ] - let expected: [PSQLBackendMessage.RowDescription.Column] = input.map { + let expected: [RowDescription.Column] = input.map { .init(name: $0.name, tableOID: $0.tableOID, columnAttributeNumber: $0.columnAttributeNumber, dataType: $0.dataType, dataTypeSize: $0.dataTypeSize, dataTypeModifier: $0.dataTypeModifier, format: .binary) } XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(queryContext, columns: expected)) - let row1: PSQLBackendMessage.DataRow = [ByteBuffer(string: "test1")] + let row1: DataRow = [ByteBuffer(string: "test1")] XCTAssertEqual(state.dataRowReceived(row1), .wait) XCTAssertEqual(state.channelReadComplete(), .forwardRows([row1])) XCTAssertEqual(state.readEventCaught(), .wait) XCTAssertEqual(state.requestQueryRows(), .read) - let row2: PSQLBackendMessage.DataRow = [ByteBuffer(string: "test2")] - let row3: PSQLBackendMessage.DataRow = [ByteBuffer(string: "test3")] - let row4: PSQLBackendMessage.DataRow = [ByteBuffer(string: "test4")] + let row2: DataRow = [ByteBuffer(string: "test2")] + let row3: DataRow = [ByteBuffer(string: "test3")] + let row4: DataRow = [ByteBuffer(string: "test4")] XCTAssertEqual(state.dataRowReceived(row2), .wait) XCTAssertEqual(state.dataRowReceived(row3), .wait) XCTAssertEqual(state.dataRowReceived(row4), .wait) @@ -69,8 +69,8 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.readEventCaught(), .read) - let row5: PSQLBackendMessage.DataRow = [ByteBuffer(string: "test5")] - let row6: PSQLBackendMessage.DataRow = [ByteBuffer(string: "test6")] + let row5: DataRow = [ByteBuffer(string: "test5")] + let row6: DataRow = [ByteBuffer(string: "test6")] XCTAssertEqual(state.dataRowReceived(row5), .wait) XCTAssertEqual(state.dataRowReceived(row6), .wait) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift index 9b88af9a..6cff280e 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift @@ -7,7 +7,7 @@ class PrepareStatementStateMachineTests: XCTestCase { func testCreatePreparedStatementReturningRowDescription() { var state = ConnectionStateMachine.readyForQuery() - let promise = EmbeddedEventLoop().makePromise(of: PSQLBackendMessage.RowDescription?.self) + let promise = EmbeddedEventLoop().makePromise(of: RowDescription?.self) promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. let name = "haha" @@ -20,7 +20,7 @@ class PrepareStatementStateMachineTests: XCTestCase { XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) - let columns: [PSQLBackendMessage.RowDescription.Column] = [ + let columns: [RowDescription.Column] = [ .init(name: "id", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: -1, format: .binary) ] @@ -32,7 +32,7 @@ class PrepareStatementStateMachineTests: XCTestCase { func testCreatePreparedStatementReturningNoData() { var state = ConnectionStateMachine.readyForQuery() - let promise = EmbeddedEventLoop().makePromise(of: PSQLBackendMessage.RowDescription?.self) + let promise = EmbeddedEventLoop().makePromise(of: RowDescription?.self) promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. let name = "haha" diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift index 8434e761..436c7aa9 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift @@ -47,11 +47,3 @@ extension PSQLBackendMessage: Equatable { } } } - -extension PSQLBackendMessage.DataRow: ExpressibleByArrayLiteral { - public typealias ArrayLiteralElement = ByteBuffer - - public init(arrayLiteral elements: ByteBuffer...) { - self.init(columns: elements) - } -} diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift index ea5323ec..75cb1afc 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift @@ -188,19 +188,10 @@ extension PSQLBackendMessage.BackendKeyData: PSQLMessagePayloadEncodable { } } -extension PSQLBackendMessage.DataRow: PSQLMessagePayloadEncodable { +extension DataRow: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { - buffer.writeInteger(Int16(self.columns.count)) - - for column in self.columns { - switch column { - case .none: - buffer.writeInteger(-1, as: Int32.self) - case .some(var writable): - buffer.writeInteger(Int32(writable.readableBytes)) - buffer.writeBuffer(&writable) - } - } + buffer.writeInteger(self.columnCount, as: Int16.self) + buffer.writeBytes(self.bytes.readableBytesView) } } @@ -255,7 +246,7 @@ extension PSQLBackendMessage.TransactionState: PSQLMessagePayloadEncodable { } } -extension PSQLBackendMessage.RowDescription: PSQLMessagePayloadEncodable { +extension RowDescription: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { buffer.writeInteger(Int16(self.columns.count)) diff --git a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift index af9ee3f2..7db44547 100644 --- a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift @@ -20,18 +20,126 @@ class DataRowTests: XCTestCase { buffer.writeBytes([UInt8](repeating: 5, count: 10)) } - let expectedColumns: [ByteBuffer?] = [ - nil, - ByteBuffer(), - ByteBuffer(bytes: [UInt8](repeating: 5, count: 10)) - ] - + let rowSlice = buffer.getSlice(at: 7, length: buffer.readableBytes - 7)! + let expectedInOuts = [ - (buffer, [PSQLBackendMessage.dataRow(.init(columns: expectedColumns))]), + (buffer, [PSQLBackendMessage.dataRow(.init(columnCount: 3, bytes: rowSlice))]), ] - + XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: expectedInOuts, decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) } + + func testIteratingElements() { + let dataRow = DataRow.makeTestDataRow(nil, ByteBuffer(), ByteBuffer(repeating: 5, count: 10)) + var iterator = dataRow.makeIterator() + + XCTAssertEqual(dataRow.count, 3) + XCTAssertEqual(iterator.next(), .some(.none)) + XCTAssertEqual(iterator.next(), ByteBuffer()) + XCTAssertEqual(iterator.next(), ByteBuffer(repeating: 5, count: 10)) + XCTAssertEqual(iterator.next(), .none) + } + + func testIndexAfterAndSubscript() { + let dataRow = DataRow.makeTestDataRow( + nil, + ByteBuffer(), + ByteBuffer(repeating: 5, count: 10), + nil + ) + + var index = dataRow.startIndex + XCTAssertEqual(dataRow[index], .none) + index = dataRow.index(after: index) + XCTAssertEqual(dataRow[index], ByteBuffer()) + index = dataRow.index(after: index) + XCTAssertEqual(dataRow[index], ByteBuffer(repeating: 5, count: 10)) + index = dataRow.index(after: index) + XCTAssertEqual(dataRow[index], .none) + index = dataRow.index(after: index) + XCTAssertEqual(index, dataRow.endIndex) + } + + func testIndexComparison() { + let dataRow = DataRow.makeTestDataRow( + nil, + ByteBuffer(), + ByteBuffer(repeating: 5, count: 10), + nil + ) + + let startIndex = dataRow.startIndex + let secondIndex = dataRow.index(after: startIndex) + + XCTAssertLessThanOrEqual(startIndex, secondIndex) + XCTAssertLessThan(startIndex, secondIndex) + + XCTAssertGreaterThanOrEqual(secondIndex, startIndex) + XCTAssertGreaterThan(secondIndex, startIndex) + + XCTAssertFalse(secondIndex == startIndex) + XCTAssertEqual(secondIndex, secondIndex) + XCTAssertEqual(startIndex, startIndex) + } + + func testColumnSubscript() { + let dataRow = DataRow.makeTestDataRow( + nil, + ByteBuffer(), + ByteBuffer(repeating: 5, count: 10), + nil + ) + + XCTAssertEqual(dataRow.count, 4) + XCTAssertEqual(dataRow[column: 0], .none) + XCTAssertEqual(dataRow[column: 1], ByteBuffer()) + XCTAssertEqual(dataRow[column: 2], ByteBuffer(repeating: 5, count: 10)) + XCTAssertEqual(dataRow[column: 3], .none) + } + + func testWithContiguousStorageIfAvailable() { + let dataRow = DataRow.makeTestDataRow( + nil, + ByteBuffer(), + ByteBuffer(repeating: 5, count: 10), + nil + ) + + XCTAssertNil(dataRow.withContiguousStorageIfAvailable { _ in + return XCTFail("DataRow does not have a contiguous storage") + }) + } } + +extension DataRow: ExpressibleByArrayLiteral { + public typealias ArrayLiteralElement = PSQLEncodable + + public init(arrayLiteral elements: PSQLEncodable...) { + + var buffer = ByteBuffer() + let encodingContext = PSQLEncodingContext(jsonEncoder: JSONEncoder()) + elements.forEach { element in + try! element.encodeRaw(into: &buffer, context: encodingContext) + } + + self.init(columnCount: Int16(elements.count), bytes: buffer) + } + + static func makeTestDataRow(_ buffers: ByteBuffer?...) -> DataRow { + var bytes = ByteBuffer() + buffers.forEach { column in + switch column { + case .none: + bytes.writeInteger(Int32(-1)) + case .some(var input): + bytes.writeInteger(Int32(input.readableBytes)) + bytes.writeBuffer(&input) + } + } + + return DataRow(columnCount: Int16(buffers.count), bytes: bytes) + } +} + diff --git a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift index 4452ebce..8eba059d 100644 --- a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift @@ -6,7 +6,7 @@ import NIOTestUtils class RowDescriptionTests: XCTestCase { func testDecode() { - let columns: [PSQLBackendMessage.RowDescription.Column] = [ + let columns: [RowDescription.Column] = [ .init(name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary), .init(name: "Second", tableOID: 123, columnAttributeNumber: 456, dataType: .uuidArray, dataTypeSize: 567, dataTypeModifier: 123, format: .text), ] @@ -42,7 +42,7 @@ class RowDescriptionTests: XCTestCase { } func testDecodeFailureBecauseOfMissingNullTerminationInColumnName() { - let column = PSQLBackendMessage.RowDescription.Column( + let column = RowDescription.Column( name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary) var buffer = ByteBuffer() @@ -65,7 +65,7 @@ class RowDescriptionTests: XCTestCase { } func testDecodeFailureBecauseOfMissingColumnCount() { - let column = PSQLBackendMessage.RowDescription.Column( + let column = RowDescription.Column( name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary) var buffer = ByteBuffer() @@ -87,7 +87,7 @@ class RowDescriptionTests: XCTestCase { } func testDecodeFailureBecauseInvalidFormatCode() { - let column = PSQLBackendMessage.RowDescription.Column( + let column = RowDescription.Column( name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary) var buffer = ByteBuffer() @@ -110,7 +110,7 @@ class RowDescriptionTests: XCTestCase { } func testDecodeFailureBecauseNegativeColumnCount() { - let column = PSQLBackendMessage.RowDescription.Column( + let column = RowDescription.Column( name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary) var buffer = ByteBuffer() From 87cfca5324fa0592f03e6c5f8b4e0c7a14f371b8 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Mon, 22 Nov 2021 06:14:04 -0600 Subject: [PATCH 037/292] Fix a few (more) oopsies in the CI (#200) Related to vapor/postgres-kit#214 and vapor/fluent-postgres-driver#186 --- .github/workflows/test.yml | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 68d7cd18..36799584 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,8 +24,8 @@ jobs: steps: - name: Check out package uses: actions/checkout@v2 - - name: Run unit tests with code coverage - run: swift test --enable-test-discovery --filter=^PostgresNIOTests --enable-code-coverage + - name: Run unit tests with code coverage and Thread Sanitizer + run: swift test --enable-test-discovery --filter=^PostgresNIOTests --sanitize=thread --enable-code-coverage - name: Convert code coverage report to most expressive format run: | export pkgname="$(swift package dump-package | perl -e 'use JSON::PP; print (decode_json(join("",(<>)))->{name});')" \ @@ -33,9 +33,9 @@ jobs: exc_prefix="$(which xcrun || true)" && \ ${exc_prefix} llvm-cov export -format lcov \ -instr-profile="$(dirname "$(swift test --show-codecov-path)")/default.profdata" \ - --ignore-filename-regex='/\.build/' --ignore-filename-regex='/Tests/' \ + --ignore-filename-regex='/(\.build|Tests)/' \ "$(swift build --show-bin-path)/${pkgname}PackageTests.xctest${subpath}" >"${pkgname}.lcov" - echo "CODECOV_FILE=$(pwd)/${pkgname}.lcov" >> $GITHUB_ENV + echo "CODECOV_FILE=$(pwd)/${pkgname}.lcov" >> "${GITHUB_ENV}" - name: Send coverage report to codecov.io uses: codecov/codecov-action@v2 with: @@ -58,7 +58,8 @@ jobs: - md5 - scram-sha-256 swiftver: - - swift:5.2 + # Only test latest Swift for integration tests, issues from older Swift versions that don't show + # up in the unit tests are fairly unlikely. - swift:5.5 swiftos: - focal @@ -67,18 +68,16 @@ jobs: env: LOG_LEVEL: debug # Unfortunately, fluent-postgres-driver details leak through here - POSTGRES_HOSTNAME: 'psql-a' POSTGRES_DB: 'test_database' - POSTGRES_DATABASE: 'test_database' - POSTGRES_DATABASE_A: 'test_database' - POSTGRES_DATABASE_B: 'test_database' + POSTGRES_DB_A: 'test_database' + POSTGRES_DB_B: 'test_database' POSTGRES_USER: 'test_username' - POSTGRES_USERNAME: 'test_username' - POSTGRES_USERNAME_A: 'test_username' - POSTGRES_USERNAME_B: 'test_username' + POSTGRES_USER_A: 'test_username' + POSTGRES_USER_B: 'test_username' POSTGRES_PASSWORD: 'test_password' POSTGRES_PASSWORD_A: 'test_password' POSTGRES_PASSWORD_B: 'test_password' + POSTGRES_HOSTNAME: 'psql-a' POSTGRES_HOSTNAME_A: 'psql-a' POSTGRES_HOSTNAME_B: 'psql-b' POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} @@ -104,7 +103,7 @@ jobs: uses: actions/checkout@v2 with: { path: 'postgres-nio' } - name: Run integration tests - run: swift test --package-path postgres-nio --enable-test-discovery --filter=^IntegrationTests + run: swift test --package-path postgres-nio --filter=^IntegrationTests - name: Check out postgres-kit dependent uses: actions/checkout@v2 with: { repository: 'vapor/postgres-kit', path: 'postgres-kit' } @@ -116,9 +115,9 @@ jobs: swift package --package-path postgres-kit edit postgres-nio --path postgres-nio swift package --package-path fluent-postgres-driver edit postgres-nio --path postgres-nio - name: Run postgres-kit tests - run: swift test --package-path postgres-kit --enable-test-discovery + run: swift test --package-path postgres-kit - name: Run fluent-postgres-driver tests - run: swift test --package-path fluent-postgres-driver --enable-test-discovery + run: swift test --package-path fluent-postgres-driver macos-all: strategy: @@ -148,8 +147,8 @@ jobs: xcode-version: ${{ matrix.xcode }} - name: Install Postgres, setup DB and auth, and wait for server start run: | - export PATH="$(brew prefix)/opt/${{ matrix.dbimage }}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test - brew install ${{ matrix.dbimage }} + export PATH="$(brew --prefix)/opt/${{ matrix.dbimage }}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test + (brew unlink postgresql || true) && brew install ${{ matrix.dbimage }} && brew link --force ${{ matrix.dbimage }} initdb --locale=C --auth-host ${{ matrix.dbauth }} -U $POSTGRES_USER --pwfile=<(echo $POSTGRES_PASSWORD) pg_ctl start --wait timeout-minutes: 2 From 3931a0694ea0960c82db4115fb66cfbc7dee8844 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 23 Nov 2021 14:17:57 +0100 Subject: [PATCH 038/292] ByteBuffer extension, prevent naming conflicts (#204) ### Motivation Because of https://bugs.swift.org/browse/SR-15517, we might run into naming conflicts with SwiftNIO, once https://github.com/apple/swift-nio/pull/1990 lands. ### Changes - Prefix all ByteBuffer utility methods ### Result Chances of breaking code reduced. --- .../Data/PostgresData+Double.swift | 6 +++--- .../PostgresNIO/Data/PostgresData+Float.swift | 4 ++-- .../PostgresMessage+Authentication.swift | 4 ++-- .../Message/PostgresMessage+Bind.swift | 4 ++-- .../Message/PostgresMessage+Close.swift | 2 +- .../PostgresMessage+CommandComplete.swift | 2 +- .../Message/PostgresMessage+Describe.swift | 2 +- .../Message/PostgresMessage+Error.swift | 2 +- .../Message/PostgresMessage+Execute.swift | 2 +- ...PostgresMessage+NotificationResponse.swift | 4 ++-- .../PostgresMessage+ParameterStatus.swift | 4 ++-- .../PostgresMessage+RowDescription.swift | 2 +- .../PostgresMessage+SASLResponse.swift | 4 ++-- .../New/Data/Float+PSQLCodable.swift | 12 +++++------ .../New/Extensions/ByteBuffer+PSQL.swift | 16 +++++++-------- .../New/Messages/Authentication.swift | 6 +++--- .../New/Messages/BackendKeyData.swift | 2 +- Sources/PostgresNIO/New/Messages/Bind.swift | 4 ++-- Sources/PostgresNIO/New/Messages/Close.swift | 4 ++-- .../PostgresNIO/New/Messages/DataRow.swift | 8 ++++---- .../PostgresNIO/New/Messages/Describe.swift | 4 ++-- .../New/Messages/ErrorResponse.swift | 2 +- .../PostgresNIO/New/Messages/Execute.swift | 2 +- .../New/Messages/NotificationResponse.swift | 6 +++--- .../New/Messages/ParameterDescription.swift | 4 ++-- .../New/Messages/ParameterStatus.swift | 4 ++-- Sources/PostgresNIO/New/Messages/Parse.swift | 4 ++-- .../PostgresNIO/New/Messages/Password.swift | 2 +- .../New/Messages/ReadyForQuery.swift | 2 +- .../New/Messages/RowDescription.swift | 6 +++--- .../New/Messages/SASLInitialResponse.swift | 2 +- .../PostgresNIO/New/Messages/Startup.swift | 14 ++++++------- .../PostgresNIO/New/PSQLBackendMessage.swift | 14 ++++++------- .../New/PSQLBackendMessageDecoder.swift | 4 ++-- .../New/PSQLFrontendMessageEncoder.swift | 2 +- .../New/Extensions/ByteBuffer+Utils.swift | 2 +- .../PSQLBackendMessageEncoder.swift | 20 +++++++++---------- .../PSQLFrontendMessageDecoder.swift | 6 +++--- .../New/Messages/BackendKeyDataTests.swift | 2 +- .../New/Messages/BindTests.swift | 4 ++-- .../New/Messages/CloseTests.swift | 4 ++-- .../New/Messages/DescribeTests.swift | 4 ++-- .../New/Messages/ErrorResponseTests.swift | 2 +- .../New/Messages/ExecuteTests.swift | 2 +- .../Messages/NotificationResponseTests.swift | 6 +++--- .../New/Messages/ParameterStatusTests.swift | 6 +++--- .../New/Messages/ParseTests.swift | 4 ++-- .../New/Messages/PasswordTests.swift | 2 +- .../New/Messages/RowDescriptionTests.swift | 8 ++++---- .../Messages/SASLInitialResponseTests.swift | 4 ++-- .../New/Messages/StartupTests.swift | 16 +++++++-------- .../New/PSQLBackendMessageTests.swift | 8 ++++---- 52 files changed, 133 insertions(+), 133 deletions(-) diff --git a/Sources/PostgresNIO/Data/PostgresData+Double.swift b/Sources/PostgresNIO/Data/PostgresData+Double.swift index 7435cdaa..986f8e23 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Double.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Double.swift @@ -3,7 +3,7 @@ import NIOCore extension PostgresData { public init(double: Double) { var buffer = ByteBufferAllocator().buffer(capacity: 0) - buffer.writeDouble(double) + buffer.psqlWriteDouble(double) self.init(type: .float8, formatCode: .binary, value: buffer) } @@ -16,10 +16,10 @@ extension PostgresData { case .binary: switch self.type { case .float4: - return value.readFloat() + return value.psqlReadFloat() .flatMap { Double($0) } case .float8: - return value.readDouble() + return value.psqlReadDouble() case .numeric: return self.numeric?.double default: diff --git a/Sources/PostgresNIO/Data/PostgresData+Float.swift b/Sources/PostgresNIO/Data/PostgresData+Float.swift index e9b7b572..9931ae9c 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Float.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Float.swift @@ -12,9 +12,9 @@ extension PostgresData { case .binary: switch self.type { case .float4: - return value.readFloat() + return value.psqlReadFloat() case .float8: - return value.readDouble() + return value.psqlReadDouble() .flatMap { Float($0) } default: return nil diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Authentication.swift b/Sources/PostgresNIO/Message/PostgresMessage+Authentication.swift index 44523a5c..e849b29d 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Authentication.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Authentication.swift @@ -23,7 +23,7 @@ extension PostgresMessage { case 10: var mechanisms: [String] = [] while buffer.readableBytes > 0 { - guard let nextString = buffer.readNullTerminatedString() else { + guard let nextString = buffer.psqlReadNullTerminatedString() else { throw PostgresError.protocol("Could not parse SASL mechanisms from authentication message") } if nextString.isEmpty { @@ -68,7 +68,7 @@ extension PostgresMessage { case .saslMechanisms(let mechanisms): buffer.writeInteger(10, as: Int32.self) mechanisms.forEach { - buffer.writeNullTerminatedString($0) + buffer.psqlWriteNullTerminatedString($0) } case .saslContinue(let challenge): buffer.writeInteger(11, as: Int32.self) diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift b/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift index a5687c40..7e85f57c 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift @@ -39,8 +39,8 @@ extension PostgresMessage { /// Serializes this message into a byte buffer. public func serialize(into buffer: inout ByteBuffer) { - buffer.writeNullTerminatedString(self.portalName) - buffer.writeNullTerminatedString(self.statementName) + buffer.psqlWriteNullTerminatedString(self.portalName) + buffer.psqlWriteNullTerminatedString(self.statementName) buffer.write(array: self.parameterFormatCodes) buffer.write(array: self.parameters) { diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Close.swift b/Sources/PostgresNIO/Message/PostgresMessage+Close.swift index 9e5dd99e..6d974ec2 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Close.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Close.swift @@ -33,7 +33,7 @@ extension PostgresMessage { /// Serializes this message into a byte buffer. public func serialize(into buffer: inout ByteBuffer) throws { buffer.writeInteger(target.rawValue) - buffer.writeNullTerminatedString(name) + buffer.psqlWriteNullTerminatedString(name) } } } diff --git a/Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift b/Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift index 406dc036..7e3035ac 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift @@ -5,7 +5,7 @@ extension PostgresMessage { public struct CommandComplete: PostgresMessageType { /// Parses an instance of this message type from a byte buffer. public static func parse(from buffer: inout ByteBuffer) throws -> CommandComplete { - guard let string = buffer.readNullTerminatedString() else { + guard let string = buffer.psqlReadNullTerminatedString() else { throw PostgresError.protocol("Could not parse close response message") } return .init(tag: string) diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Describe.swift b/Sources/PostgresNIO/Message/PostgresMessage+Describe.swift index 8c3bc8f5..c41e5b44 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Describe.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Describe.swift @@ -31,7 +31,7 @@ extension PostgresMessage { /// Serializes this message into a byte buffer. public func serialize(into buffer: inout ByteBuffer) { buffer.writeInteger(command.rawValue) - buffer.writeNullTerminatedString(name) + buffer.psqlWriteNullTerminatedString(name) } } } diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Error.swift b/Sources/PostgresNIO/Message/PostgresMessage+Error.swift index 51b9be7e..6aca3387 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Error.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Error.swift @@ -11,7 +11,7 @@ extension PostgresMessage { public static func parse(from buffer: inout ByteBuffer) throws -> Error { var fields: [Field: String] = [:] while let field = buffer.readInteger(as: Field.self) { - guard let string = buffer.readNullTerminatedString() else { + guard let string = buffer.psqlReadNullTerminatedString() else { throw PostgresError.protocol("Could not read error response string.") } fields[field] = string diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Execute.swift b/Sources/PostgresNIO/Message/PostgresMessage+Execute.swift index 4b8bc999..3451ef64 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Execute.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Execute.swift @@ -20,7 +20,7 @@ extension PostgresMessage { /// Serializes this message into a byte buffer. public func serialize(into buffer: inout ByteBuffer) { - buffer.writeNullTerminatedString(portalName) + buffer.psqlWriteNullTerminatedString(portalName) buffer.writeInteger(self.maxRows) } } diff --git a/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift b/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift index 4979e354..27d8df80 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift @@ -10,10 +10,10 @@ extension PostgresMessage { guard let backendPID: Int32 = buffer.readInteger() else { throw PostgresError.protocol("Invalid NotificationResponse message: unable to read backend PID") } - guard let channel = buffer.readNullTerminatedString() else { + guard let channel = buffer.psqlReadNullTerminatedString() else { throw PostgresError.protocol("Invalid NotificationResponse message: unable to read channel") } - guard let payload = buffer.readNullTerminatedString() else { + guard let payload = buffer.psqlReadNullTerminatedString() else { throw PostgresError.protocol("Invalid NotificationResponse message: unable to read payload") } return .init(backendPID: backendPID, channel: channel, payload: payload) diff --git a/Sources/PostgresNIO/Message/PostgresMessage+ParameterStatus.swift b/Sources/PostgresNIO/Message/PostgresMessage+ParameterStatus.swift index 5e2f5881..59af4c1f 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+ParameterStatus.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+ParameterStatus.swift @@ -4,10 +4,10 @@ extension PostgresMessage { public struct ParameterStatus: PostgresMessageType, CustomStringConvertible { /// Parses an instance of this message type from a byte buffer. public static func parse(from buffer: inout ByteBuffer) throws -> ParameterStatus { - guard let parameter = buffer.readNullTerminatedString() else { + guard let parameter = buffer.psqlReadNullTerminatedString() else { throw PostgresError.protocol("Could not read parameter from parameter status message") } - guard let value = buffer.readNullTerminatedString() else { + guard let value = buffer.psqlReadNullTerminatedString() else { throw PostgresError.protocol("Could not read value from parameter status message") } return .init(parameter: parameter, value: value) diff --git a/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift b/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift index 48a90c18..cddaac1d 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift @@ -11,7 +11,7 @@ extension PostgresMessage { /// Describes a single field returns in a `RowDescription` message. public struct Field: CustomStringConvertible { static func parse(from buffer: inout ByteBuffer) throws -> Field { - guard let name = buffer.readNullTerminatedString() else { + guard let name = buffer.psqlReadNullTerminatedString() else { throw PostgresError.protocol("Could not read row description field name") } guard let tableOID = buffer.readInteger(as: UInt32.self) else { diff --git a/Sources/PostgresNIO/Message/PostgresMessage+SASLResponse.swift b/Sources/PostgresNIO/Message/PostgresMessage+SASLResponse.swift index 553edc2c..66b4cb5f 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+SASLResponse.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+SASLResponse.swift @@ -38,7 +38,7 @@ extension PostgresMessage { public let initialData: [UInt8] public static func parse(from buffer: inout ByteBuffer) throws -> PostgresMessage.SASLInitialResponse { - guard let mechanism = buffer.readNullTerminatedString() else { + guard let mechanism = buffer.psqlReadNullTerminatedString() else { throw PostgresError.protocol("Could not parse SASL mechanism from initial response message") } guard let dataLength = buffer.readInteger(as: Int32.self) else { @@ -57,7 +57,7 @@ extension PostgresMessage { } public func serialize(into buffer: inout ByteBuffer) throws { - buffer.writeNullTerminatedString(mechanism) + buffer.psqlWriteNullTerminatedString(mechanism) if initialData.count > 0 { buffer.writeInteger(Int32(initialData.count), as: Int32.self) // write(array:) writes Int16, which is incorrect here buffer.writeBytes(initialData) diff --git a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift index e86894a2..6a551e64 100644 --- a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift @@ -12,12 +12,12 @@ extension Float: PSQLCodable { static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Float { switch (format, type) { case (.binary, .float4): - guard buffer.readableBytes == 4, let float = buffer.readFloat() else { + guard buffer.readableBytes == 4, let float = buffer.psqlReadFloat() else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } return float case (.binary, .float8): - guard buffer.readableBytes == 8, let double = buffer.readDouble() else { + guard buffer.readableBytes == 8, let double = buffer.psqlReadDouble() else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } return Float(double) @@ -32,7 +32,7 @@ extension Float: PSQLCodable { } func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { - byteBuffer.writeFloat(self) + byteBuffer.psqlWriteFloat(self) } } @@ -48,12 +48,12 @@ extension Double: PSQLCodable { static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Double { switch (format, type) { case (.binary, .float4): - guard buffer.readableBytes == 4, let float = buffer.readFloat() else { + guard buffer.readableBytes == 4, let float = buffer.psqlReadFloat() else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } return Double(float) case (.binary, .float8): - guard buffer.readableBytes == 8, let double = buffer.readDouble() else { + guard buffer.readableBytes == 8, let double = buffer.psqlReadDouble() else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } return double @@ -68,7 +68,7 @@ extension Double: PSQLCodable { } func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { - byteBuffer.writeDouble(self) + byteBuffer.psqlWriteDouble(self) } } diff --git a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift index 45197cc0..79d5256e 100644 --- a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift +++ b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift @@ -1,12 +1,12 @@ import NIOCore internal extension ByteBuffer { - mutating func writeNullTerminatedString(_ string: String) { + mutating func psqlWriteNullTerminatedString(_ string: String) { self.writeString(string) self.writeInteger(0, as: UInt8.self) } - mutating func readNullTerminatedString() -> String? { + mutating func psqlReadNullTerminatedString() -> String? { guard let nullIndex = readableBytesView.firstIndex(of: 0) else { return nil } @@ -15,27 +15,27 @@ internal extension ByteBuffer { return readString(length: nullIndex - readerIndex) } - mutating func writeBackendMessageID(_ messageID: PSQLBackendMessage.ID) { + mutating func psqlWriteBackendMessageID(_ messageID: PSQLBackendMessage.ID) { self.writeInteger(messageID.rawValue) } - mutating func writeFrontendMessageID(_ messageID: PSQLFrontendMessage.ID) { + mutating func psqlWriteFrontendMessageID(_ messageID: PSQLFrontendMessage.ID) { self.writeInteger(messageID.rawValue) } - mutating func readFloat() -> Float? { + mutating func psqlReadFloat() -> Float? { return self.readInteger(as: UInt32.self).map { Float(bitPattern: $0) } } - mutating func readDouble() -> Double? { + mutating func psqlReadDouble() -> Double? { return self.readInteger(as: UInt64.self).map { Double(bitPattern: $0) } } - mutating func writeFloat(_ float: Float) { + mutating func psqlWriteFloat(_ float: Float) { self.writeInteger(float.bitPattern) } - mutating func writeDouble(_ double: Double) { + mutating func psqlWriteDouble(_ double: Double) { self.writeInteger(double.bitPattern) } } diff --git a/Sources/PostgresNIO/New/Messages/Authentication.swift b/Sources/PostgresNIO/New/Messages/Authentication.swift index 5ce5b857..92b000a0 100644 --- a/Sources/PostgresNIO/New/Messages/Authentication.swift +++ b/Sources/PostgresNIO/New/Messages/Authentication.swift @@ -16,7 +16,7 @@ extension PSQLBackendMessage { case saslFinal(data: ByteBuffer) static func decode(from buffer: inout ByteBuffer) throws -> Self { - try buffer.ensureAtLeastNBytesRemaining(2) + try buffer.psqlEnsureAtLeastNBytesRemaining(2) // we have at least two bytes remaining, therefore we can force unwrap this read. let authID = buffer.readInteger(as: Int32.self)! @@ -29,7 +29,7 @@ extension PSQLBackendMessage { case 3: return .plaintext case 5: - try buffer.ensureExactNBytesRemaining(4) + try buffer.psqlEnsureExactNBytesRemaining(4) let salt1 = buffer.readInteger(as: UInt8.self)! let salt2 = buffer.readInteger(as: UInt8.self)! let salt3 = buffer.readInteger(as: UInt8.self)! @@ -47,7 +47,7 @@ extension PSQLBackendMessage { case 10: var names = [String]() let endIndex = buffer.readerIndex + buffer.readableBytes - while buffer.readerIndex < endIndex, let next = buffer.readNullTerminatedString() { + while buffer.readerIndex < endIndex, let next = buffer.psqlReadNullTerminatedString() { names.append(next) } diff --git a/Sources/PostgresNIO/New/Messages/BackendKeyData.swift b/Sources/PostgresNIO/New/Messages/BackendKeyData.swift index dfb5738e..fdc41439 100644 --- a/Sources/PostgresNIO/New/Messages/BackendKeyData.swift +++ b/Sources/PostgresNIO/New/Messages/BackendKeyData.swift @@ -7,7 +7,7 @@ extension PSQLBackendMessage { let secretKey: Int32 static func decode(from buffer: inout ByteBuffer) throws -> Self { - try buffer.ensureExactNBytesRemaining(8) + try buffer.psqlEnsureExactNBytesRemaining(8) // We have verified the correct length before, this means we have exactly eight bytes // to read. If we have enough readable bytes, a read of Int32 should always succeed. diff --git a/Sources/PostgresNIO/New/Messages/Bind.swift b/Sources/PostgresNIO/New/Messages/Bind.swift index 110d7866..dd3465b2 100644 --- a/Sources/PostgresNIO/New/Messages/Bind.swift +++ b/Sources/PostgresNIO/New/Messages/Bind.swift @@ -13,8 +13,8 @@ extension PSQLFrontendMessage { var parameters: [PSQLEncodable] func encode(into buffer: inout ByteBuffer, using jsonEncoder: PSQLJSONEncoder) throws { - buffer.writeNullTerminatedString(self.portalName) - buffer.writeNullTerminatedString(self.preparedStatementName) + buffer.psqlWriteNullTerminatedString(self.portalName) + buffer.psqlWriteNullTerminatedString(self.preparedStatementName) // The number of parameter format codes that follow (denoted C below). This can be // zero to indicate that there are no parameters or that the parameters all use the diff --git a/Sources/PostgresNIO/New/Messages/Close.swift b/Sources/PostgresNIO/New/Messages/Close.swift index 5ed532e6..ae70f758 100644 --- a/Sources/PostgresNIO/New/Messages/Close.swift +++ b/Sources/PostgresNIO/New/Messages/Close.swift @@ -10,10 +10,10 @@ extension PSQLFrontendMessage { switch self { case .preparedStatement(let name): buffer.writeInteger(UInt8(ascii: "S")) - buffer.writeNullTerminatedString(name) + buffer.psqlWriteNullTerminatedString(name) case .portal(let name): buffer.writeInteger(UInt8(ascii: "P")) - buffer.writeNullTerminatedString(name) + buffer.psqlWriteNullTerminatedString(name) } } } diff --git a/Sources/PostgresNIO/New/Messages/DataRow.swift b/Sources/PostgresNIO/New/Messages/DataRow.swift index 54044c6a..1828128b 100644 --- a/Sources/PostgresNIO/New/Messages/DataRow.swift +++ b/Sources/PostgresNIO/New/Messages/DataRow.swift @@ -15,12 +15,12 @@ struct DataRow: PSQLBackendMessage.PayloadDecodable, Equatable { var bytes: ByteBuffer static func decode(from buffer: inout ByteBuffer) throws -> Self { - try buffer.ensureAtLeastNBytesRemaining(2) + try buffer.psqlEnsureAtLeastNBytesRemaining(2) let columnCount = buffer.readInteger(as: Int16.self)! let firstColumnIndex = buffer.readerIndex for _ in 0..= 0 else { @@ -28,11 +28,11 @@ struct DataRow: PSQLBackendMessage.PayloadDecodable, Equatable { continue } - try buffer.ensureAtLeastNBytesRemaining(bufferLength) + try buffer.psqlEnsureAtLeastNBytesRemaining(bufferLength) buffer.moveReaderIndex(forwardBy: bufferLength) } - try buffer.ensureExactNBytesRemaining(0) + try buffer.psqlEnsureExactNBytesRemaining(0) buffer.moveReaderIndex(to: firstColumnIndex) let columnSlice = buffer.readSlice(length: buffer.readableBytes)! diff --git a/Sources/PostgresNIO/New/Messages/Describe.swift b/Sources/PostgresNIO/New/Messages/Describe.swift index 0a3105cc..104d7127 100644 --- a/Sources/PostgresNIO/New/Messages/Describe.swift +++ b/Sources/PostgresNIO/New/Messages/Describe.swift @@ -11,10 +11,10 @@ extension PSQLFrontendMessage { switch self { case .preparedStatement(let name): buffer.writeInteger(UInt8(ascii: "S")) - buffer.writeNullTerminatedString(name) + buffer.psqlWriteNullTerminatedString(name) case .portal(let name): buffer.writeInteger(UInt8(ascii: "P")) - buffer.writeNullTerminatedString(name) + buffer.psqlWriteNullTerminatedString(name) } } } diff --git a/Sources/PostgresNIO/New/Messages/ErrorResponse.swift b/Sources/PostgresNIO/New/Messages/ErrorResponse.swift index 254cdf0f..891c7e9b 100644 --- a/Sources/PostgresNIO/New/Messages/ErrorResponse.swift +++ b/Sources/PostgresNIO/New/Messages/ErrorResponse.swift @@ -117,7 +117,7 @@ extension PSQLBackendMessage.PayloadDecodable where Self: PSQLMessageNotice { asType: PSQLBackendMessage.Field.self) } - guard let string = buffer.readNullTerminatedString() else { + guard let string = buffer.psqlReadNullTerminatedString() else { throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } fields[field] = string diff --git a/Sources/PostgresNIO/New/Messages/Execute.swift b/Sources/PostgresNIO/New/Messages/Execute.swift index 891bd9aa..2cf13922 100644 --- a/Sources/PostgresNIO/New/Messages/Execute.swift +++ b/Sources/PostgresNIO/New/Messages/Execute.swift @@ -15,7 +15,7 @@ extension PSQLFrontendMessage { } func encode(into buffer: inout ByteBuffer) { - buffer.writeNullTerminatedString(self.portalName) + buffer.psqlWriteNullTerminatedString(self.portalName) buffer.writeInteger(self.maxNumberOfRows) } } diff --git a/Sources/PostgresNIO/New/Messages/NotificationResponse.swift b/Sources/PostgresNIO/New/Messages/NotificationResponse.swift index b1430e2a..afc860fc 100644 --- a/Sources/PostgresNIO/New/Messages/NotificationResponse.swift +++ b/Sources/PostgresNIO/New/Messages/NotificationResponse.swift @@ -8,13 +8,13 @@ extension PSQLBackendMessage { let payload: String static func decode(from buffer: inout ByteBuffer) throws -> PSQLBackendMessage.NotificationResponse { - try buffer.ensureAtLeastNBytesRemaining(6) + try buffer.psqlEnsureAtLeastNBytesRemaining(6) let backendPID = buffer.readInteger(as: Int32.self)! - guard let channel = buffer.readNullTerminatedString() else { + guard let channel = buffer.psqlReadNullTerminatedString() else { throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } - guard let payload = buffer.readNullTerminatedString() else { + guard let payload = buffer.psqlReadNullTerminatedString() else { throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } diff --git a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift index fdf64aad..49062fda 100644 --- a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift +++ b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift @@ -7,14 +7,14 @@ extension PSQLBackendMessage { var dataTypes: [PSQLDataType] static func decode(from buffer: inout ByteBuffer) throws -> Self { - try buffer.ensureAtLeastNBytesRemaining(2) + try buffer.psqlEnsureAtLeastNBytesRemaining(2) let parameterCount = buffer.readInteger(as: Int16.self)! guard parameterCount >= 0 else { throw PSQLPartialDecodingError.integerMustBePositiveOrNull(parameterCount) } - try buffer.ensureExactNBytesRemaining(Int(parameterCount) * 4) + try buffer.psqlEnsureExactNBytesRemaining(Int(parameterCount) * 4) var result = [PSQLDataType]() result.reserveCapacity(Int(parameterCount)) diff --git a/Sources/PostgresNIO/New/Messages/ParameterStatus.swift b/Sources/PostgresNIO/New/Messages/ParameterStatus.swift index 89dd1d6d..ebf1e212 100644 --- a/Sources/PostgresNIO/New/Messages/ParameterStatus.swift +++ b/Sources/PostgresNIO/New/Messages/ParameterStatus.swift @@ -10,11 +10,11 @@ extension PSQLBackendMessage { var value: String static func decode(from buffer: inout ByteBuffer) throws -> Self { - guard let name = buffer.readNullTerminatedString() else { + guard let name = buffer.psqlReadNullTerminatedString() else { throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } - guard let value = buffer.readNullTerminatedString() else { + guard let value = buffer.psqlReadNullTerminatedString() else { throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } diff --git a/Sources/PostgresNIO/New/Messages/Parse.swift b/Sources/PostgresNIO/New/Messages/Parse.swift index 1d0aec19..72eb4962 100644 --- a/Sources/PostgresNIO/New/Messages/Parse.swift +++ b/Sources/PostgresNIO/New/Messages/Parse.swift @@ -13,8 +13,8 @@ extension PSQLFrontendMessage { let parameters: [PSQLDataType] func encode(into buffer: inout ByteBuffer) { - buffer.writeNullTerminatedString(self.preparedStatementName) - buffer.writeNullTerminatedString(self.query) + buffer.psqlWriteNullTerminatedString(self.preparedStatementName) + buffer.psqlWriteNullTerminatedString(self.query) buffer.writeInteger(Int16(self.parameters.count)) self.parameters.forEach { dataType in diff --git a/Sources/PostgresNIO/New/Messages/Password.swift b/Sources/PostgresNIO/New/Messages/Password.swift index 88e885f9..df1bd327 100644 --- a/Sources/PostgresNIO/New/Messages/Password.swift +++ b/Sources/PostgresNIO/New/Messages/Password.swift @@ -6,7 +6,7 @@ extension PSQLFrontendMessage { let value: String func encode(into buffer: inout ByteBuffer) { - buffer.writeNullTerminatedString(value) + buffer.psqlWriteNullTerminatedString(value) } } diff --git a/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift b/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift index 20420763..74b30200 100644 --- a/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift +++ b/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift @@ -33,7 +33,7 @@ extension PSQLBackendMessage { } static func decode(from buffer: inout ByteBuffer) throws -> Self { - try buffer.ensureExactNBytesRemaining(1) + try buffer.psqlEnsureExactNBytesRemaining(1) // Exactly one byte is readable. For this reason, we can force unwrap the UInt8 below let value = buffer.readInteger(as: UInt8.self)! diff --git a/Sources/PostgresNIO/New/Messages/RowDescription.swift b/Sources/PostgresNIO/New/Messages/RowDescription.swift index 49f09baa..ade0e85c 100644 --- a/Sources/PostgresNIO/New/Messages/RowDescription.swift +++ b/Sources/PostgresNIO/New/Messages/RowDescription.swift @@ -37,7 +37,7 @@ struct RowDescription: PSQLBackendMessage.PayloadDecodable, Equatable { } static func decode(from buffer: inout ByteBuffer) throws -> Self { - try buffer.ensureAtLeastNBytesRemaining(2) + try buffer.psqlEnsureAtLeastNBytesRemaining(2) let columnCount = buffer.readInteger(as: Int16.self)! guard columnCount >= 0 else { @@ -48,11 +48,11 @@ struct RowDescription: PSQLBackendMessage.PayloadDecodable, Equatable { result.reserveCapacity(Int(columnCount)) for _ in 0.. 0 { buffer.writeInteger(Int32(self.initialData.count)) diff --git a/Sources/PostgresNIO/New/Messages/Startup.swift b/Sources/PostgresNIO/New/Messages/Startup.swift index 148b8bc2..0ceb1050 100644 --- a/Sources/PostgresNIO/New/Messages/Startup.swift +++ b/Sources/PostgresNIO/New/Messages/Startup.swift @@ -51,29 +51,29 @@ extension PSQLFrontendMessage { /// Serializes this message into a byte buffer. func encode(into buffer: inout ByteBuffer) { buffer.writeInteger(self.protocolVersion) - buffer.writeNullTerminatedString("user") + buffer.psqlWriteNullTerminatedString("user") buffer.writeString(self.parameters.user) buffer.writeInteger(UInt8(0)) if let database = self.parameters.database { - buffer.writeNullTerminatedString("database") + buffer.psqlWriteNullTerminatedString("database") buffer.writeString(database) buffer.writeInteger(UInt8(0)) } if let options = self.parameters.options { - buffer.writeNullTerminatedString("options") + buffer.psqlWriteNullTerminatedString("options") buffer.writeString(options) buffer.writeInteger(UInt8(0)) } switch self.parameters.replication { case .database: - buffer.writeNullTerminatedString("replication") - buffer.writeNullTerminatedString("replication") + buffer.psqlWriteNullTerminatedString("replication") + buffer.psqlWriteNullTerminatedString("replication") case .true: - buffer.writeNullTerminatedString("replication") - buffer.writeNullTerminatedString("true") + buffer.psqlWriteNullTerminatedString("replication") + buffer.psqlWriteNullTerminatedString("true") case .false: break } diff --git a/Sources/PostgresNIO/New/PSQLBackendMessage.swift b/Sources/PostgresNIO/New/PSQLBackendMessage.swift index d65f4623..c71789f1 100644 --- a/Sources/PostgresNIO/New/PSQLBackendMessage.swift +++ b/Sources/PostgresNIO/New/PSQLBackendMessage.swift @@ -193,27 +193,27 @@ extension PSQLBackendMessage { case .backendKeyData: return try .backendKeyData(.decode(from: &buffer)) case .bindComplete: - try buffer.ensureExactNBytesRemaining(0) + try buffer.psqlEnsureExactNBytesRemaining(0) return .bindComplete case .closeComplete: - try buffer.ensureExactNBytesRemaining(0) + try buffer.psqlEnsureExactNBytesRemaining(0) return .closeComplete case .commandComplete: - guard let commandTag = buffer.readNullTerminatedString() else { + guard let commandTag = buffer.psqlReadNullTerminatedString() else { throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } return .commandComplete(commandTag) case .dataRow: return try .dataRow(.decode(from: &buffer)) case .emptyQueryResponse: - try buffer.ensureExactNBytesRemaining(0) + try buffer.psqlEnsureExactNBytesRemaining(0) return .emptyQueryResponse case .parameterStatus: return try .parameterStatus(.decode(from: &buffer)) case .error: return try .error(.decode(from: &buffer)) case .noData: - try buffer.ensureExactNBytesRemaining(0) + try buffer.psqlEnsureExactNBytesRemaining(0) return .noData case .noticeResponse: return try .notice(.decode(from: &buffer)) @@ -222,10 +222,10 @@ extension PSQLBackendMessage { case .parameterDescription: return try .parameterDescription(.decode(from: &buffer)) case .parseComplete: - try buffer.ensureExactNBytesRemaining(0) + try buffer.psqlEnsureExactNBytesRemaining(0) return .parseComplete case .portalSuspended: - try buffer.ensureExactNBytesRemaining(0) + try buffer.psqlEnsureExactNBytesRemaining(0) return .portalSuspended case .readyForQuery: return try .readyForQuery(.decode(from: &buffer)) diff --git a/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift b/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift index edf386df..dd4e4ebf 100644 --- a/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift +++ b/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift @@ -192,13 +192,13 @@ struct PSQLPartialDecodingError: Error { } extension ByteBuffer { - func ensureAtLeastNBytesRemaining(_ n: Int, file: String = #file, line: Int = #line) throws { + func psqlEnsureAtLeastNBytesRemaining(_ n: Int, file: String = #file, line: Int = #line) throws { guard self.readableBytes >= n else { throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(2, actual: self.readableBytes, file: file, line: line) } } - func ensureExactNBytesRemaining(_ n: Int, file: String = #file, line: Int = #line) throws { + func psqlEnsureExactNBytesRemaining(_ n: Int, file: String = #file, line: Int = #line) throws { guard self.readableBytes == n else { throw PSQLPartialDecodingError.expectedExactlyNRemainingBytes(n, actual: self.readableBytes, file: file, line: line) } diff --git a/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift index 0a998285..227cd233 100644 --- a/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift @@ -71,7 +71,7 @@ struct PSQLFrontendMessageEncoder: MessageToByteEncoder { payload: Payload, into buffer: inout ByteBuffer) { - buffer.writeFrontendMessageID(messageID) + buffer.psqlWriteFrontendMessageID(messageID) self.encode(payload: payload, into: &buffer) } diff --git a/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift b/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift index 9d1cfb81..835965da 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift @@ -10,7 +10,7 @@ extension ByteBuffer { } mutating func writeBackendMessage(id: PSQLBackendMessage.ID, _ payload: (inout ByteBuffer) throws -> ()) rethrows { - self.writeBackendMessageID(id) + self.psqlWriteBackendMessageID(id) let lengthIndex = self.writerIndex self.writeInteger(Int32(0)) try payload(&self) diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift index 75cb1afc..6c1be6f5 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift @@ -68,7 +68,7 @@ struct PSQLBackendMessageEncoder: MessageToByteEncoder { var string: String init(_ string: String) { self.string = string } func encode(into buffer: inout ByteBuffer) { - buffer.writeNullTerminatedString(self.string) + buffer.psqlWriteNullTerminatedString(self.string) } } @@ -77,7 +77,7 @@ struct PSQLBackendMessageEncoder: MessageToByteEncoder { payload: Payload, into buffer: inout ByteBuffer) { - buffer.writeBackendMessageID(messageID) + buffer.psqlWriteBackendMessageID(messageID) let startIndex = buffer.writerIndex buffer.writeInteger(Int32(0)) // placeholder for length payload.encode(into: &buffer) @@ -166,7 +166,7 @@ extension PSQLBackendMessage.Authentication: PSQLMessagePayloadEncodable { case .sasl(names: let names): buffer.writeInteger(Int32(10)) for name in names { - buffer.writeNullTerminatedString(name) + buffer.psqlWriteNullTerminatedString(name) } case .saslContinue(data: var data): @@ -199,7 +199,7 @@ extension PSQLBackendMessage.ErrorResponse: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { for (key, value) in self.fields { buffer.writeInteger(key.rawValue, as: UInt8.self) - buffer.writeNullTerminatedString(value) + buffer.psqlWriteNullTerminatedString(value) } buffer.writeInteger(0, as: UInt8.self) // signal done } @@ -209,7 +209,7 @@ extension PSQLBackendMessage.NoticeResponse: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { for (key, value) in self.fields { buffer.writeInteger(key.rawValue, as: UInt8.self) - buffer.writeNullTerminatedString(value) + buffer.psqlWriteNullTerminatedString(value) } buffer.writeInteger(0, as: UInt8.self) // signal done } @@ -218,8 +218,8 @@ extension PSQLBackendMessage.NoticeResponse: PSQLMessagePayloadEncodable { extension PSQLBackendMessage.NotificationResponse: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { buffer.writeInteger(self.backendPID) - buffer.writeNullTerminatedString(self.channel) - buffer.writeNullTerminatedString(self.payload) + buffer.psqlWriteNullTerminatedString(self.channel) + buffer.psqlWriteNullTerminatedString(self.payload) } } @@ -235,8 +235,8 @@ extension PSQLBackendMessage.ParameterDescription: PSQLMessagePayloadEncodable { extension PSQLBackendMessage.ParameterStatus: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { - buffer.writeNullTerminatedString(self.parameter) - buffer.writeNullTerminatedString(self.value) + buffer.psqlWriteNullTerminatedString(self.parameter) + buffer.psqlWriteNullTerminatedString(self.value) } } @@ -251,7 +251,7 @@ extension RowDescription: PSQLMessagePayloadEncodable { buffer.writeInteger(Int16(self.columns.count)) for column in self.columns { - buffer.writeNullTerminatedString(column.name) + buffer.psqlWriteNullTerminatedString(column.name) buffer.writeInteger(column.tableOID) buffer.writeInteger(column.columnAttributeNumber) buffer.writeInteger(column.dataType.rawValue) diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index c639f4b2..4bf988ae 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -40,8 +40,8 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { var database: String? var options: String? - while let name = messageSlice.readNullTerminatedString(), messageSlice.readerIndex < finalIndex { - let value = messageSlice.readNullTerminatedString() + while let name = messageSlice.psqlReadNullTerminatedString(), messageSlice.readerIndex < finalIndex { + let value = messageSlice.psqlReadNullTerminatedString() switch name { case "user": @@ -136,7 +136,7 @@ extension PSQLFrontendMessage { case .parse: preconditionFailure("TODO: Unimplemented") case .password: - guard let password = buffer.readNullTerminatedString() else { + guard let password = buffer.psqlReadNullTerminatedString() else { throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } return .password(.init(value: password)) diff --git a/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift b/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift index eca5ba02..5715c61c 100644 --- a/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift @@ -21,7 +21,7 @@ class BackendKeyDataTests: XCTestCase { func testDecodeInvalidLength() { var buffer = ByteBuffer() - buffer.writeBackendMessageID(.backendKeyData) + buffer.psqlWriteBackendMessageID(.backendKeyData) buffer.writeInteger(Int32(11)) buffer.writeInteger(Int32(1234)) buffer.writeInteger(Int32(4567)) diff --git a/Tests/PostgresNIOTests/New/Messages/BindTests.swift b/Tests/PostgresNIOTests/New/Messages/BindTests.swift index 7a688d41..234e1541 100644 --- a/Tests/PostgresNIOTests/New/Messages/BindTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BindTests.swift @@ -14,8 +14,8 @@ class BindTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, 37) XCTAssertEqual(PSQLFrontendMessage.ID.bind.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 36) - XCTAssertEqual("", byteBuffer.readNullTerminatedString()) - XCTAssertEqual("", byteBuffer.readNullTerminatedString()) + XCTAssertEqual("", byteBuffer.psqlReadNullTerminatedString()) + XCTAssertEqual("", byteBuffer.psqlReadNullTerminatedString()) // the number of parameters XCTAssertEqual(2, byteBuffer.readInteger(as: Int16.self)) // all (two) parameters have the same format (binary) diff --git a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift index 4df15896..8f8af2bd 100644 --- a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift @@ -14,7 +14,7 @@ class CloseTests: XCTestCase { XCTAssertEqual(PSQLFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(11, byteBuffer.readInteger(as: Int32.self)) XCTAssertEqual(UInt8(ascii: "P"), byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual("Hello", byteBuffer.readNullTerminatedString()) + XCTAssertEqual("Hello", byteBuffer.psqlReadNullTerminatedString()) XCTAssertEqual(byteBuffer.readableBytes, 0) } @@ -28,7 +28,7 @@ class CloseTests: XCTestCase { XCTAssertEqual(PSQLFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(6, byteBuffer.readInteger(as: Int32.self)) XCTAssertEqual(UInt8(ascii: "S"), byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual("", byteBuffer.readNullTerminatedString()) + XCTAssertEqual("", byteBuffer.psqlReadNullTerminatedString()) XCTAssertEqual(byteBuffer.readableBytes, 0) } diff --git a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift index 87f7d09b..fabb0e29 100644 --- a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift @@ -14,7 +14,7 @@ class DescribeTests: XCTestCase { XCTAssertEqual(PSQLFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(11, byteBuffer.readInteger(as: Int32.self)) XCTAssertEqual(UInt8(ascii: "P"), byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual("Hello", byteBuffer.readNullTerminatedString()) + XCTAssertEqual("Hello", byteBuffer.psqlReadNullTerminatedString()) XCTAssertEqual(byteBuffer.readableBytes, 0) } @@ -28,7 +28,7 @@ class DescribeTests: XCTestCase { XCTAssertEqual(PSQLFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(6, byteBuffer.readInteger(as: Int32.self)) XCTAssertEqual(UInt8(ascii: "S"), byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual("", byteBuffer.readNullTerminatedString()) + XCTAssertEqual("", byteBuffer.psqlReadNullTerminatedString()) XCTAssertEqual(byteBuffer.readableBytes, 0) } diff --git a/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift index bbc945e4..df0d63b0 100644 --- a/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift @@ -19,7 +19,7 @@ class ErrorResponseTests: XCTestCase { let buffer = ByteBuffer.backendMessage(id: .error) { buffer in fields.forEach { (key, value) in buffer.writeInteger(key.rawValue, as: UInt8.self) - buffer.writeNullTerminatedString(value) + buffer.psqlWriteNullTerminatedString(value) } buffer.writeInteger(0, as: UInt8.self) // signal done } diff --git a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift index 3ce8d63d..0969194c 100644 --- a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift @@ -13,7 +13,7 @@ class ExecuteTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, 10) // 1 (id) + 4 (length) + 1 (empty null terminated string) + 4 (count) XCTAssertEqual(PSQLFrontendMessage.ID.execute.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(9, byteBuffer.readInteger(as: Int32.self)) // length - XCTAssertEqual("", byteBuffer.readNullTerminatedString()) + XCTAssertEqual("", byteBuffer.psqlReadNullTerminatedString()) XCTAssertEqual(0, byteBuffer.readInteger(as: Int32.self)) } } diff --git a/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift index 39fbb220..abf6b4ed 100644 --- a/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift @@ -20,8 +20,8 @@ class NotificationResponseTests: XCTestCase { buffer.writeBackendMessage(id: .notificationResponse) { buffer in buffer.writeInteger(notification.backendPID) - buffer.writeNullTerminatedString(notification.channel) - buffer.writeNullTerminatedString(notification.payload) + buffer.psqlWriteNullTerminatedString(notification.channel) + buffer.psqlWriteNullTerminatedString(notification.payload) } } @@ -49,7 +49,7 @@ class NotificationResponseTests: XCTestCase { var buffer = ByteBuffer() buffer.writeBackendMessage(id: .notificationResponse) { buffer in buffer.writeInteger(Int32(123)) - buffer.writeNullTerminatedString("hello") + buffer.psqlWriteNullTerminatedString("hello") buffer.writeString("world") } diff --git a/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift b/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift index db4963e0..2f00fa53 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift @@ -27,8 +27,8 @@ class ParameterStatusTests: XCTestCase { switch message { case .parameterStatus(let parameterStatus): buffer.writeBackendMessage(id: .parameterStatus) { buffer in - buffer.writeNullTerminatedString(parameterStatus.parameter) - buffer.writeNullTerminatedString(parameterStatus.value) + buffer.psqlWriteNullTerminatedString(parameterStatus.parameter) + buffer.psqlWriteNullTerminatedString(parameterStatus.value) } case .backendKeyData(let backendKeyData): buffer.writeBackendMessage(id: .backendKeyData) { buffer in @@ -62,7 +62,7 @@ class ParameterStatusTests: XCTestCase { func testDecodeFailureBecauseOfMissingNullTerminationInValue() { var buffer = ByteBuffer() buffer.writeBackendMessage(id: .parameterStatus) { buffer in - buffer.writeNullTerminatedString("DateStyle") + buffer.psqlWriteNullTerminatedString("DateStyle") buffer.writeString("ISO, MDY") } diff --git a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift index c147b749..3393e74d 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift @@ -24,8 +24,8 @@ class ParseTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, length) XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.parse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), parse.preparedStatementName) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), parse.query) + XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), parse.preparedStatementName) + XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), parse.query) XCTAssertEqual(byteBuffer.readInteger(as: Int16.self), Int16(parse.parameters.count)) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), PSQLDataType.bool.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), PSQLDataType.int8.rawValue) diff --git a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift index 73c464f3..f7876426 100644 --- a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift @@ -16,6 +16,6 @@ class PasswordTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, expectedLength) XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.password.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(expectedLength - 1)) // length - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "md522d085ed8dc3377968dc1c1a40519a2a") + XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), "md522d085ed8dc3377968dc1c1a40519a2a") } } diff --git a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift index 8eba059d..ba759dc4 100644 --- a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift @@ -25,7 +25,7 @@ class RowDescriptionTests: XCTestCase { buffer.writeInteger(Int16(description.columns.count)) description.columns.forEach { column in - buffer.writeNullTerminatedString(column.name) + buffer.psqlWriteNullTerminatedString(column.name) buffer.writeInteger(column.tableOID) buffer.writeInteger(column.columnAttributeNumber) buffer.writeInteger(column.dataType.rawValue) @@ -70,7 +70,7 @@ class RowDescriptionTests: XCTestCase { var buffer = ByteBuffer() buffer.writeBackendMessage(id: .rowDescription) { buffer in - buffer.writeNullTerminatedString(column.name) + buffer.psqlWriteNullTerminatedString(column.name) buffer.writeInteger(column.tableOID) buffer.writeInteger(column.columnAttributeNumber) buffer.writeInteger(column.dataType.rawValue) @@ -93,7 +93,7 @@ class RowDescriptionTests: XCTestCase { var buffer = ByteBuffer() buffer.writeBackendMessage(id: .rowDescription) { buffer in buffer.writeInteger(Int16(1)) - buffer.writeNullTerminatedString(column.name) + buffer.psqlWriteNullTerminatedString(column.name) buffer.writeInteger(column.tableOID) buffer.writeInteger(column.columnAttributeNumber) buffer.writeInteger(column.dataType.rawValue) @@ -116,7 +116,7 @@ class RowDescriptionTests: XCTestCase { var buffer = ByteBuffer() buffer.writeBackendMessage(id: .rowDescription) { buffer in buffer.writeInteger(Int16(-1)) - buffer.writeNullTerminatedString(column.name) + buffer.psqlWriteNullTerminatedString(column.name) buffer.writeInteger(column.tableOID) buffer.writeInteger(column.columnAttributeNumber) buffer.writeInteger(column.dataType.rawValue) diff --git a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift index af2459ac..3c4ae4b3 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift @@ -23,7 +23,7 @@ class SASLInitialResponseTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, length) XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.saslInitialResponse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), sasl.saslMechanism) + XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), sasl.saslMechanism) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(sasl.initialData.count)) XCTAssertEqual(byteBuffer.readBytes(length: sasl.initialData.count), sasl.initialData) XCTAssertEqual(byteBuffer.readableBytes, 0) @@ -48,7 +48,7 @@ class SASLInitialResponseTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, length) XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.saslInitialResponse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), sasl.saslMechanism) + XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), sasl.saslMechanism) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(-1)) XCTAssertEqual(byteBuffer.readableBytes, 0) } diff --git a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift index 1224aede..ee63ea1a 100644 --- a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift @@ -29,15 +29,15 @@ class StartupTests: XCTestCase { let byteBufferLength = Int32(byteBuffer.readableBytes) XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) XCTAssertEqual(startup.protocolVersion, byteBuffer.readInteger()) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "options") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "some options") + XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), "user") + XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), "test") + XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), "database") + XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), "abc123") + XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), "options") + XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), "some options") if replication != .false { - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "replication") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), replication.stringValue) + XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), "replication") + XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), replication.stringValue) } XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) diff --git a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift index 049e23d1..0f486180 100644 --- a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift @@ -90,8 +90,8 @@ class PSQLBackendMessageTests: XCTestCase { parameterStatus.forEach { parameterStatus in buffer.writeBackendMessage(id: .parameterStatus) { buffer in - buffer.writeNullTerminatedString(parameterStatus.parameter) - buffer.writeNullTerminatedString(parameterStatus.value) + buffer.psqlWriteNullTerminatedString(parameterStatus.parameter) + buffer.psqlWriteNullTerminatedString(parameterStatus.value) } expectedMessages.append(.parameterStatus(parameterStatus)) @@ -132,7 +132,7 @@ class PSQLBackendMessageTests: XCTestCase { buffer.writeBackendMessage(id: .noticeResponse) { buffer in fields.forEach { (key, value) in buffer.writeInteger(key.rawValue, as: UInt8.self) - buffer.writeNullTerminatedString(value) + buffer.psqlWriteNullTerminatedString(value) } buffer.writeInteger(0, as: UInt8.self) // signal done } @@ -216,7 +216,7 @@ class PSQLBackendMessageTests: XCTestCase { } okBuffer.writeBackendMessage(id: .commandComplete) { buffer in - buffer.writeNullTerminatedString(commandTag) + buffer.psqlWriteNullTerminatedString(commandTag) } } From f91f23db099199b0846b0fdee1b02a9b00c4b749 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Tue, 23 Nov 2021 08:29:00 -0600 Subject: [PATCH 039/292] Switch to our codecov-action wrapper (#202) * Switch to our codecov-action wrapper * Update the code coverage action for main branch too --- .github/workflows/main-codecov.yml | 19 +++++++------------ .github/workflows/test.yml | 26 +++++++------------------- 2 files changed, 14 insertions(+), 31 deletions(-) diff --git a/.github/workflows/main-codecov.yml b/.github/workflows/main-codecov.yml index c6db91d5..7a55c3ae 100644 --- a/.github/workflows/main-codecov.yml +++ b/.github/workflows/main-codecov.yml @@ -12,16 +12,11 @@ jobs: uses: actions/checkout@v2 - name: Run unit tests with code coverage and Thread Sanitizer run: swift test --enable-code-coverage --sanitize=thread --filter=^PostgresNIOTests - - name: Convert profdata to LCOV for upload - run: | - llvm-cov export -format lcov \ - -instr-profile="$(dirname $(swift test --show-codecov-path))/default.profdata" \ - --ignore-filename-regex='/(\.build|Tests)/' \ - "$(swift build --show-bin-path)/postgres-nioPackageTests.xctest" >postgres-nio.lcov - echo "CODECOV_FILE=$(pwd)/postgres-nio.lcov" >>"${GITHUB_ENV}" - - name: Upload LCOV report to Codecov.io - uses: codecov/codecov-action@v2 + - name: Submit coverage report to Codecov.io + uses: vapor/swift-codecov-action@v0.1.1 with: - files: ${{ env.CODECOV_FILE }} - flags: 'unittests' - fail_ci_if_error: true + cc_flags: 'unittests' + cc_env_vars: 'SWIFT_VERSION,SWIFT_PLATFORM,RUNNER_OS,RUNNER_ARCH' + cc_fail_ci_if_error: true + cc_verbose: true + cc_dry_run: false diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 36799584..ccf4b474 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,7 +1,5 @@ name: test on: [ 'pull_request' ] -env: - LOG_LEVEL: notice jobs: linux-unit: @@ -20,29 +18,19 @@ jobs: runs-on: ubuntu-latest env: LOG_LEVEL: debug - MATRIX_CONFIG: ${{ format('{0}-{1}', matrix.swiftver, matrix.swiftos) }} steps: - name: Check out package uses: actions/checkout@v2 - name: Run unit tests with code coverage and Thread Sanitizer run: swift test --enable-test-discovery --filter=^PostgresNIOTests --sanitize=thread --enable-code-coverage - - name: Convert code coverage report to most expressive format - run: | - export pkgname="$(swift package dump-package | perl -e 'use JSON::PP; print (decode_json(join("",(<>)))->{name});')" \ - subpath="$([ "$(uname -s)" = 'Darwin' ] && echo "/Contents/MacOS/${pkgname}PackageTests" || true)" \ - exc_prefix="$(which xcrun || true)" && \ - ${exc_prefix} llvm-cov export -format lcov \ - -instr-profile="$(dirname "$(swift test --show-codecov-path)")/default.profdata" \ - --ignore-filename-regex='/(\.build|Tests)/' \ - "$(swift build --show-bin-path)/${pkgname}PackageTests.xctest${subpath}" >"${pkgname}.lcov" - echo "CODECOV_FILE=$(pwd)/${pkgname}.lcov" >> "${GITHUB_ENV}" - - name: Send coverage report to codecov.io - uses: codecov/codecov-action@v2 + - name: Submit coverage report to Codecov.io + uses: vapor/swift-codecov-action@v0.1.1 with: - files: ${{ env.CODECOV_FILE }} - flags: 'unittests' - env_vars: 'MATRIX_CONFIG' - fail_ci_if_error: true + cc_flags: 'unittests' + cc_env_vars: 'SWIFT_VERSION,SWIFT_PLATFORM,RUNNER_OS,RUNNER_ARCH' + cc_fail_ci_if_error: true + cc_verbose: true + cc_dry_run: false linux-integration-and-dependencies: strategy: From 2c49bee33daa81f9bb8617910de8167152f84969 Mon Sep 17 00:00:00 2001 From: Mads Odgaard Date: Wed, 24 Nov 2021 19:46:56 +0100 Subject: [PATCH 040/292] Add proper support for `Decimal` (#194) * Use `PostgresNumeric` for `Decimal` instead of String * Make `Decimal` conform to `PSQLCodable` * Fix support for text decimals * Add integration test for decimal string serialization * Test inserting decimal to text column Co-authored-by: Gwynne Raskind --- .../Data/PostgresData+Decimal.swift | 4 +- .../New/Data/Decimal+PSQLCodable.swift | 39 +++++++++++++++++++ .../PSQLIntegrationTests.swift | 25 ++++++++++++ Tests/IntegrationTests/PostgresNIOTests.swift | 36 ++++++++++++++--- .../New/Data/Decimal+PSQLCodableTests.swift | 32 +++++++++++++++ 5 files changed, 128 insertions(+), 8 deletions(-) create mode 100644 Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift create mode 100644 Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift diff --git a/Sources/PostgresNIO/Data/PostgresData+Decimal.swift b/Sources/PostgresNIO/Data/PostgresData+Decimal.swift index f98e06af..0d2047b6 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Decimal.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Decimal.swift @@ -18,7 +18,7 @@ extension PostgresData { extension Decimal: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { - return String.postgresDataType + return .numeric } public init?(postgresData: PostgresData) { @@ -29,6 +29,6 @@ extension Decimal: PostgresDataConvertible { } public var postgresData: PostgresData? { - return .init(decimal: self) + return .init(numeric: PostgresNumeric(decimal: self)) } } diff --git a/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift new file mode 100644 index 00000000..de42a874 --- /dev/null +++ b/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift @@ -0,0 +1,39 @@ +import NIOCore +import struct Foundation.Decimal + +extension Decimal: PSQLCodable { + var psqlType: PSQLDataType { + .numeric + } + + var psqlFormat: PSQLFormat { + .binary + } + + static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Decimal { + switch (format, type) { + case (.binary, .numeric): + guard let numeric = PostgresNumeric(buffer: &byteBuffer) else { + throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: byteBuffer, context: context) + } + return numeric.decimal + case (.text, .numeric): + guard let string = byteBuffer.readString(length: byteBuffer.readableBytes), let value = Decimal(string: string) else { + throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: byteBuffer, context: context) + } + return value + default: + throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: byteBuffer, context: context) + } + } + + func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + let numeric = PostgresNumeric(decimal: self) + byteBuffer.writeInteger(numeric.ndigits) + byteBuffer.writeInteger(numeric.weight) + byteBuffer.writeInteger(numeric.sign) + byteBuffer.writeInteger(numeric.dscale) + var value = numeric.value + byteBuffer.writeBuffer(&value) + } +} diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index dabe9f1c..f3d63add 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -251,6 +251,31 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(try row?.decode(column: "timestamptz", as: Date.self).description, "2016-01-18 00:20:03 +0000") } + func testDecodeDecimals() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PSQLConnection? + XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + var stream: PSQLRowStream? + XCTAssertNoThrow(stream = try conn?.query(""" + SELECT + $1::numeric as numeric, + $2::numeric as numeric_negative + """, [Decimal(string: "123456.789123")!, Decimal(string: "-123456.789123")!], logger: .psqlTest).wait()) + + var rows: [PSQLRow]? + XCTAssertNoThrow(rows = try stream?.all().wait()) + XCTAssertEqual(rows?.count, 1) + let row = rows?.first + + XCTAssertEqual(try row?.decode(column: "numeric", as: Decimal.self), Decimal(string: "123456.789123")!) + XCTAssertEqual(try row?.decode(column: "numeric_negative", as: Decimal.self), Decimal(string: "-123456.789123")!) + } + func testDecodeUUID() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index 739735bb..ff1fb804 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -466,17 +466,41 @@ final class PostgresNIOTests: XCTestCase { var rows: PostgresQueryResult? XCTAssertNoThrow(rows = try conn?.query(""" select - $1::numeric::text as a, - $2::numeric::text as b, - $3::numeric::text as c + $1::numeric as a, + $2::numeric as b, + $3::numeric as c """, [ .init(numeric: a), .init(numeric: b), .init(numeric: c) ]).wait()) - XCTAssertEqual(rows?.first?.column("a")?.string, "123456.789123") - XCTAssertEqual(rows?.first?.column("b")?.string, "-123456.789123") - XCTAssertEqual(rows?.first?.column("c")?.string, "3.14159265358979") + XCTAssertEqual(rows?.first?.column("a")?.decimal, Decimal(string: "123456.789123")!) + XCTAssertEqual(rows?.first?.column("b")?.decimal, Decimal(string: "-123456.789123")!) + XCTAssertEqual(rows?.first?.column("c")?.decimal, Decimal(string: "3.14159265358979")!) + } + + func testDecimalStringSerialization() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + + XCTAssertNoThrow(_ = try conn?.simpleQuery("DROP TABLE IF EXISTS \"table1\"").wait()) + XCTAssertNoThrow(_ = try conn?.simpleQuery(""" + CREATE TABLE table1 ( + "balance" text NOT NULL + ); + """).wait()) + defer { XCTAssertNoThrow(_ = try conn?.simpleQuery("DROP TABLE \"table1\"").wait()) } + + XCTAssertNoThrow(_ = try conn?.query("INSERT INTO table1 VALUES ($1)", [.init(decimal: Decimal(string: "123456.789123")!)]).wait()) + + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query(""" + SELECT + "balance" + FROM table1 + """).wait()) + XCTAssertEqual(rows?.first?.column("balance")?.decimal, Decimal(string: "123456.789123")!) } func testMoney() { diff --git a/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift new file mode 100644 index 00000000..afdcad20 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift @@ -0,0 +1,32 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class Decimal_PSQLCodableTests: XCTestCase { + + func testRoundTrip() { + let values: [Decimal] = [1.1, .pi, -5e-12] + + for value in values { + var buffer = ByteBuffer() + value.encode(into: &buffer, context: .forTests()) + XCTAssertEqual(value.psqlType, .numeric) + let data = PSQLData(bytes: buffer, dataType: .numeric, format: .binary) + + var result: Decimal? + XCTAssertNoThrow(result = try data.decode(as: Decimal.self, context: .forTests())) + XCTAssertEqual(value, result) + } + } + + func testDecodeFailureInvalidType() { + var buffer = ByteBuffer() + buffer.writeInteger(Int64(0)) + let data = PSQLData(bytes: buffer, dataType: .int8, format: .binary) + + XCTAssertThrowsError(try data.decode(as: Decimal.self, context: .forTests())) { error in + XCTAssert(error is PSQLCastingError) + } + } + +} From 81ca9092902556bfd9d9c3f1819fcea757d37ff9 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 27 Nov 2021 00:06:50 +0100 Subject: [PATCH 041/292] Faster decoding, thanks to fewer bound checks. (#203) --- Package.swift | 2 +- .../PostgresMessage+Authentication.swift | 4 +- .../Message/PostgresMessage+Bind.swift | 4 +- .../Message/PostgresMessage+Close.swift | 2 +- .../PostgresMessage+CommandComplete.swift | 2 +- .../Message/PostgresMessage+Describe.swift | 2 +- .../Message/PostgresMessage+Error.swift | 2 +- .../Message/PostgresMessage+Execute.swift | 2 +- ...PostgresMessage+NotificationResponse.swift | 4 +- .../PostgresMessage+ParameterStatus.swift | 4 +- .../PostgresMessage+RowDescription.swift | 2 +- .../PostgresMessage+SASLResponse.swift | 4 +- .../New/Data/Array+PSQLCodable.swift | 23 +++--- .../New/Extensions/ByteBuffer+PSQL.swift | 13 ---- .../New/Messages/Authentication.swift | 17 ++--- .../New/Messages/BackendKeyData.swift | 12 ++-- Sources/PostgresNIO/New/Messages/Bind.swift | 4 +- Sources/PostgresNIO/New/Messages/Cancel.swift | 4 +- Sources/PostgresNIO/New/Messages/Close.swift | 4 +- .../PostgresNIO/New/Messages/DataRow.swift | 14 ++-- .../PostgresNIO/New/Messages/Describe.swift | 4 +- .../New/Messages/ErrorResponse.swift | 2 +- .../PostgresNIO/New/Messages/Execute.swift | 2 +- .../New/Messages/NotificationResponse.swift | 7 +- .../New/Messages/ParameterDescription.swift | 8 +-- .../New/Messages/ParameterStatus.swift | 4 +- Sources/PostgresNIO/New/Messages/Parse.swift | 4 +- .../PostgresNIO/New/Messages/Password.swift | 2 +- .../New/Messages/ReadyForQuery.swift | 5 +- .../New/Messages/RowDescription.swift | 18 ++--- .../New/Messages/SASLInitialResponse.swift | 2 +- .../PostgresNIO/New/Messages/Startup.swift | 23 +++--- .../PostgresNIO/New/PSQLBackendMessage.swift | 25 +++++-- .../New/PSQLBackendMessageDecoder.swift | 71 +++++++++---------- .../PSQLBackendMessageEncoder.swift | 18 ++--- .../PSQLFrontendMessageDecoder.swift | 6 +- .../New/Messages/BindTests.swift | 4 +- .../New/Messages/CloseTests.swift | 4 +- .../New/Messages/DescribeTests.swift | 4 +- .../New/Messages/ErrorResponseTests.swift | 2 +- .../New/Messages/ExecuteTests.swift | 2 +- .../Messages/NotificationResponseTests.swift | 6 +- .../New/Messages/ParameterStatusTests.swift | 6 +- .../New/Messages/ParseTests.swift | 4 +- .../New/Messages/PasswordTests.swift | 2 +- .../New/Messages/RowDescriptionTests.swift | 8 +-- .../Messages/SASLInitialResponseTests.swift | 4 +- .../New/Messages/StartupTests.swift | 16 ++--- .../New/PSQLBackendMessageTests.swift | 8 +-- 49 files changed, 178 insertions(+), 218 deletions(-) diff --git a/Package.swift b/Package.swift index 64c261b3..510c04fe 100644 --- a/Package.swift +++ b/Package.swift @@ -13,7 +13,7 @@ let package = Package( .library(name: "PostgresNIO", targets: ["PostgresNIO"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-nio.git", from: "2.33.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.35.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.1"), .package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "3.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.0.0"), diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Authentication.swift b/Sources/PostgresNIO/Message/PostgresMessage+Authentication.swift index e849b29d..44523a5c 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Authentication.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Authentication.swift @@ -23,7 +23,7 @@ extension PostgresMessage { case 10: var mechanisms: [String] = [] while buffer.readableBytes > 0 { - guard let nextString = buffer.psqlReadNullTerminatedString() else { + guard let nextString = buffer.readNullTerminatedString() else { throw PostgresError.protocol("Could not parse SASL mechanisms from authentication message") } if nextString.isEmpty { @@ -68,7 +68,7 @@ extension PostgresMessage { case .saslMechanisms(let mechanisms): buffer.writeInteger(10, as: Int32.self) mechanisms.forEach { - buffer.psqlWriteNullTerminatedString($0) + buffer.writeNullTerminatedString($0) } case .saslContinue(let challenge): buffer.writeInteger(11, as: Int32.self) diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift b/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift index 7e85f57c..a5687c40 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift @@ -39,8 +39,8 @@ extension PostgresMessage { /// Serializes this message into a byte buffer. public func serialize(into buffer: inout ByteBuffer) { - buffer.psqlWriteNullTerminatedString(self.portalName) - buffer.psqlWriteNullTerminatedString(self.statementName) + buffer.writeNullTerminatedString(self.portalName) + buffer.writeNullTerminatedString(self.statementName) buffer.write(array: self.parameterFormatCodes) buffer.write(array: self.parameters) { diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Close.swift b/Sources/PostgresNIO/Message/PostgresMessage+Close.swift index 6d974ec2..9e5dd99e 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Close.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Close.swift @@ -33,7 +33,7 @@ extension PostgresMessage { /// Serializes this message into a byte buffer. public func serialize(into buffer: inout ByteBuffer) throws { buffer.writeInteger(target.rawValue) - buffer.psqlWriteNullTerminatedString(name) + buffer.writeNullTerminatedString(name) } } } diff --git a/Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift b/Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift index 7e3035ac..406dc036 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift @@ -5,7 +5,7 @@ extension PostgresMessage { public struct CommandComplete: PostgresMessageType { /// Parses an instance of this message type from a byte buffer. public static func parse(from buffer: inout ByteBuffer) throws -> CommandComplete { - guard let string = buffer.psqlReadNullTerminatedString() else { + guard let string = buffer.readNullTerminatedString() else { throw PostgresError.protocol("Could not parse close response message") } return .init(tag: string) diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Describe.swift b/Sources/PostgresNIO/Message/PostgresMessage+Describe.swift index c41e5b44..8c3bc8f5 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Describe.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Describe.swift @@ -31,7 +31,7 @@ extension PostgresMessage { /// Serializes this message into a byte buffer. public func serialize(into buffer: inout ByteBuffer) { buffer.writeInteger(command.rawValue) - buffer.psqlWriteNullTerminatedString(name) + buffer.writeNullTerminatedString(name) } } } diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Error.swift b/Sources/PostgresNIO/Message/PostgresMessage+Error.swift index 6aca3387..51b9be7e 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Error.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Error.swift @@ -11,7 +11,7 @@ extension PostgresMessage { public static func parse(from buffer: inout ByteBuffer) throws -> Error { var fields: [Field: String] = [:] while let field = buffer.readInteger(as: Field.self) { - guard let string = buffer.psqlReadNullTerminatedString() else { + guard let string = buffer.readNullTerminatedString() else { throw PostgresError.protocol("Could not read error response string.") } fields[field] = string diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Execute.swift b/Sources/PostgresNIO/Message/PostgresMessage+Execute.swift index 3451ef64..4b8bc999 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Execute.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Execute.swift @@ -20,7 +20,7 @@ extension PostgresMessage { /// Serializes this message into a byte buffer. public func serialize(into buffer: inout ByteBuffer) { - buffer.psqlWriteNullTerminatedString(portalName) + buffer.writeNullTerminatedString(portalName) buffer.writeInteger(self.maxRows) } } diff --git a/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift b/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift index 27d8df80..4979e354 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift @@ -10,10 +10,10 @@ extension PostgresMessage { guard let backendPID: Int32 = buffer.readInteger() else { throw PostgresError.protocol("Invalid NotificationResponse message: unable to read backend PID") } - guard let channel = buffer.psqlReadNullTerminatedString() else { + guard let channel = buffer.readNullTerminatedString() else { throw PostgresError.protocol("Invalid NotificationResponse message: unable to read channel") } - guard let payload = buffer.psqlReadNullTerminatedString() else { + guard let payload = buffer.readNullTerminatedString() else { throw PostgresError.protocol("Invalid NotificationResponse message: unable to read payload") } return .init(backendPID: backendPID, channel: channel, payload: payload) diff --git a/Sources/PostgresNIO/Message/PostgresMessage+ParameterStatus.swift b/Sources/PostgresNIO/Message/PostgresMessage+ParameterStatus.swift index 59af4c1f..5e2f5881 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+ParameterStatus.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+ParameterStatus.swift @@ -4,10 +4,10 @@ extension PostgresMessage { public struct ParameterStatus: PostgresMessageType, CustomStringConvertible { /// Parses an instance of this message type from a byte buffer. public static func parse(from buffer: inout ByteBuffer) throws -> ParameterStatus { - guard let parameter = buffer.psqlReadNullTerminatedString() else { + guard let parameter = buffer.readNullTerminatedString() else { throw PostgresError.protocol("Could not read parameter from parameter status message") } - guard let value = buffer.psqlReadNullTerminatedString() else { + guard let value = buffer.readNullTerminatedString() else { throw PostgresError.protocol("Could not read value from parameter status message") } return .init(parameter: parameter, value: value) diff --git a/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift b/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift index cddaac1d..48a90c18 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift @@ -11,7 +11,7 @@ extension PostgresMessage { /// Describes a single field returns in a `RowDescription` message. public struct Field: CustomStringConvertible { static func parse(from buffer: inout ByteBuffer) throws -> Field { - guard let name = buffer.psqlReadNullTerminatedString() else { + guard let name = buffer.readNullTerminatedString() else { throw PostgresError.protocol("Could not read row description field name") } guard let tableOID = buffer.readInteger(as: UInt32.self) else { diff --git a/Sources/PostgresNIO/Message/PostgresMessage+SASLResponse.swift b/Sources/PostgresNIO/Message/PostgresMessage+SASLResponse.swift index 66b4cb5f..553edc2c 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+SASLResponse.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+SASLResponse.swift @@ -38,7 +38,7 @@ extension PostgresMessage { public let initialData: [UInt8] public static func parse(from buffer: inout ByteBuffer) throws -> PostgresMessage.SASLInitialResponse { - guard let mechanism = buffer.psqlReadNullTerminatedString() else { + guard let mechanism = buffer.readNullTerminatedString() else { throw PostgresError.protocol("Could not parse SASL mechanism from initial response message") } guard let dataLength = buffer.readInteger(as: Int32.self) else { @@ -57,7 +57,7 @@ extension PostgresMessage { } public func serialize(into buffer: inout ByteBuffer) throws { - buffer.psqlWriteNullTerminatedString(mechanism) + buffer.writeNullTerminatedString(mechanism) if initialData.count > 0 { buffer.writeInteger(Int32(initialData.count), as: Int32.self) // write(array:) writes Int16, which is incorrect here buffer.writeBytes(initialData) diff --git a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift index d2211885..07e67c2d 100644 --- a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift @@ -108,30 +108,25 @@ extension Array: PSQLDecodable where Element: PSQLArrayElement { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } - guard let isNotEmpty = buffer.readInteger(as: Int32.self), (0...1).contains(isNotEmpty) else { + guard let (isNotEmpty, b, element) = buffer.readMultipleIntegers(endianness: .big, as: (Int32, Int32, Int32).self), + 0 <= isNotEmpty, isNotEmpty <= 1, b == 0 + else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } - guard let b = buffer.readInteger(as: Int32.self), b == 0 else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) - } - - guard let elementType = buffer.readInteger(as: PSQLDataType.self) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) - } + let elementType = PSQLDataType(rawValue: element) guard isNotEmpty == 1 else { return [] } - guard let expectedArrayCount = buffer.readInteger(as: Int32.self), expectedArrayCount > 0 else { + guard let (expectedArrayCount, dimensions) = buffer.readMultipleIntegers(endianness: .big, as: (Int32, Int32).self), + expectedArrayCount > 0, + dimensions == 1 + else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } - - guard let dimensions = buffer.readInteger(as: Int32.self), dimensions == 1 else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) - } - + var result = Array() result.reserveCapacity(Int(expectedArrayCount)) diff --git a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift index 79d5256e..a948b41b 100644 --- a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift +++ b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift @@ -1,19 +1,6 @@ import NIOCore internal extension ByteBuffer { - mutating func psqlWriteNullTerminatedString(_ string: String) { - self.writeString(string) - self.writeInteger(0, as: UInt8.self) - } - - mutating func psqlReadNullTerminatedString() -> String? { - guard let nullIndex = readableBytesView.firstIndex(of: 0) else { - return nil - } - - defer { moveReaderIndex(forwardBy: 1) } - return readString(length: nullIndex - readerIndex) - } mutating func psqlWriteBackendMessageID(_ messageID: PSQLBackendMessage.ID) { self.writeInteger(messageID.rawValue) diff --git a/Sources/PostgresNIO/New/Messages/Authentication.swift b/Sources/PostgresNIO/New/Messages/Authentication.swift index 92b000a0..54d7c6ad 100644 --- a/Sources/PostgresNIO/New/Messages/Authentication.swift +++ b/Sources/PostgresNIO/New/Messages/Authentication.swift @@ -16,10 +16,7 @@ extension PSQLBackendMessage { case saslFinal(data: ByteBuffer) static func decode(from buffer: inout ByteBuffer) throws -> Self { - try buffer.psqlEnsureAtLeastNBytesRemaining(2) - - // we have at least two bytes remaining, therefore we can force unwrap this read. - let authID = buffer.readInteger(as: Int32.self)! + let authID = try buffer.throwingReadInteger(as: Int32.self) switch authID { case 0: @@ -29,12 +26,10 @@ extension PSQLBackendMessage { case 3: return .plaintext case 5: - try buffer.psqlEnsureExactNBytesRemaining(4) - let salt1 = buffer.readInteger(as: UInt8.self)! - let salt2 = buffer.readInteger(as: UInt8.self)! - let salt3 = buffer.readInteger(as: UInt8.self)! - let salt4 = buffer.readInteger(as: UInt8.self)! - return .md5(salt: (salt1, salt2, salt3, salt4)) + guard let salt = buffer.readMultipleIntegers(endianness: .big, as: (UInt8, UInt8, UInt8, UInt8).self) else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(4, actual: buffer.readableBytes) + } + return .md5(salt: salt) case 6: return .scmCredential case 7: @@ -47,7 +42,7 @@ extension PSQLBackendMessage { case 10: var names = [String]() let endIndex = buffer.readerIndex + buffer.readableBytes - while buffer.readerIndex < endIndex, let next = buffer.psqlReadNullTerminatedString() { + while buffer.readerIndex < endIndex, let next = buffer.readNullTerminatedString() { names.append(next) } diff --git a/Sources/PostgresNIO/New/Messages/BackendKeyData.swift b/Sources/PostgresNIO/New/Messages/BackendKeyData.swift index fdc41439..2d6a23a4 100644 --- a/Sources/PostgresNIO/New/Messages/BackendKeyData.swift +++ b/Sources/PostgresNIO/New/Messages/BackendKeyData.swift @@ -7,14 +7,10 @@ extension PSQLBackendMessage { let secretKey: Int32 static func decode(from buffer: inout ByteBuffer) throws -> Self { - try buffer.psqlEnsureExactNBytesRemaining(8) - - // We have verified the correct length before, this means we have exactly eight bytes - // to read. If we have enough readable bytes, a read of Int32 should always succeed. - // Therefore we can force unwrap here. - let processID = buffer.readInteger(as: Int32.self)! - let secretKey = buffer.readInteger(as: Int32.self)! - + guard let (processID, secretKey) = buffer.readMultipleIntegers(endianness: .big, as: (Int32, Int32).self) else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(8, actual: buffer.readableBytes) + } + return .init(processID: processID, secretKey: secretKey) } } diff --git a/Sources/PostgresNIO/New/Messages/Bind.swift b/Sources/PostgresNIO/New/Messages/Bind.swift index dd3465b2..110d7866 100644 --- a/Sources/PostgresNIO/New/Messages/Bind.swift +++ b/Sources/PostgresNIO/New/Messages/Bind.swift @@ -13,8 +13,8 @@ extension PSQLFrontendMessage { var parameters: [PSQLEncodable] func encode(into buffer: inout ByteBuffer, using jsonEncoder: PSQLJSONEncoder) throws { - buffer.psqlWriteNullTerminatedString(self.portalName) - buffer.psqlWriteNullTerminatedString(self.preparedStatementName) + buffer.writeNullTerminatedString(self.portalName) + buffer.writeNullTerminatedString(self.preparedStatementName) // The number of parameter format codes that follow (denoted C below). This can be // zero to indicate that there are no parameters or that the parameters all use the diff --git a/Sources/PostgresNIO/New/Messages/Cancel.swift b/Sources/PostgresNIO/New/Messages/Cancel.swift index d2756580..64107d7a 100644 --- a/Sources/PostgresNIO/New/Messages/Cancel.swift +++ b/Sources/PostgresNIO/New/Messages/Cancel.swift @@ -15,9 +15,7 @@ extension PSQLFrontendMessage { let secretKey: Int32 func encode(into buffer: inout ByteBuffer) { - buffer.writeInteger(self.cancelRequestCode) - buffer.writeInteger(self.processID) - buffer.writeInteger(self.secretKey) + buffer.writeMultipleIntegers(self.cancelRequestCode, self.processID, self.secretKey) } } } diff --git a/Sources/PostgresNIO/New/Messages/Close.swift b/Sources/PostgresNIO/New/Messages/Close.swift index ae70f758..5ed532e6 100644 --- a/Sources/PostgresNIO/New/Messages/Close.swift +++ b/Sources/PostgresNIO/New/Messages/Close.swift @@ -10,10 +10,10 @@ extension PSQLFrontendMessage { switch self { case .preparedStatement(let name): buffer.writeInteger(UInt8(ascii: "S")) - buffer.psqlWriteNullTerminatedString(name) + buffer.writeNullTerminatedString(name) case .portal(let name): buffer.writeInteger(UInt8(ascii: "P")) - buffer.psqlWriteNullTerminatedString(name) + buffer.writeNullTerminatedString(name) } } } diff --git a/Sources/PostgresNIO/New/Messages/DataRow.swift b/Sources/PostgresNIO/New/Messages/DataRow.swift index 1828128b..31148c20 100644 --- a/Sources/PostgresNIO/New/Messages/DataRow.swift +++ b/Sources/PostgresNIO/New/Messages/DataRow.swift @@ -15,25 +15,19 @@ struct DataRow: PSQLBackendMessage.PayloadDecodable, Equatable { var bytes: ByteBuffer static func decode(from buffer: inout ByteBuffer) throws -> Self { - try buffer.psqlEnsureAtLeastNBytesRemaining(2) - let columnCount = buffer.readInteger(as: Int16.self)! + let columnCount = try buffer.throwingReadInteger(as: Int16.self) let firstColumnIndex = buffer.readerIndex for _ in 0..= 0 else { // if buffer length is negative, this means that the value is null continue } - - try buffer.psqlEnsureAtLeastNBytesRemaining(bufferLength) - buffer.moveReaderIndex(forwardBy: bufferLength) + + try buffer.throwingMoveReaderIndex(forwardBy: Int(bufferLength)) } - try buffer.psqlEnsureExactNBytesRemaining(0) - buffer.moveReaderIndex(to: firstColumnIndex) let columnSlice = buffer.readSlice(length: buffer.readableBytes)! return DataRow(columnCount: columnCount, bytes: columnSlice) diff --git a/Sources/PostgresNIO/New/Messages/Describe.swift b/Sources/PostgresNIO/New/Messages/Describe.swift index 104d7127..0a3105cc 100644 --- a/Sources/PostgresNIO/New/Messages/Describe.swift +++ b/Sources/PostgresNIO/New/Messages/Describe.swift @@ -11,10 +11,10 @@ extension PSQLFrontendMessage { switch self { case .preparedStatement(let name): buffer.writeInteger(UInt8(ascii: "S")) - buffer.psqlWriteNullTerminatedString(name) + buffer.writeNullTerminatedString(name) case .portal(let name): buffer.writeInteger(UInt8(ascii: "P")) - buffer.psqlWriteNullTerminatedString(name) + buffer.writeNullTerminatedString(name) } } } diff --git a/Sources/PostgresNIO/New/Messages/ErrorResponse.swift b/Sources/PostgresNIO/New/Messages/ErrorResponse.swift index 891c7e9b..254cdf0f 100644 --- a/Sources/PostgresNIO/New/Messages/ErrorResponse.swift +++ b/Sources/PostgresNIO/New/Messages/ErrorResponse.swift @@ -117,7 +117,7 @@ extension PSQLBackendMessage.PayloadDecodable where Self: PSQLMessageNotice { asType: PSQLBackendMessage.Field.self) } - guard let string = buffer.psqlReadNullTerminatedString() else { + guard let string = buffer.readNullTerminatedString() else { throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } fields[field] = string diff --git a/Sources/PostgresNIO/New/Messages/Execute.swift b/Sources/PostgresNIO/New/Messages/Execute.swift index 2cf13922..891bd9aa 100644 --- a/Sources/PostgresNIO/New/Messages/Execute.swift +++ b/Sources/PostgresNIO/New/Messages/Execute.swift @@ -15,7 +15,7 @@ extension PSQLFrontendMessage { } func encode(into buffer: inout ByteBuffer) { - buffer.psqlWriteNullTerminatedString(self.portalName) + buffer.writeNullTerminatedString(self.portalName) buffer.writeInteger(self.maxNumberOfRows) } } diff --git a/Sources/PostgresNIO/New/Messages/NotificationResponse.swift b/Sources/PostgresNIO/New/Messages/NotificationResponse.swift index afc860fc..dd5c0cf2 100644 --- a/Sources/PostgresNIO/New/Messages/NotificationResponse.swift +++ b/Sources/PostgresNIO/New/Messages/NotificationResponse.swift @@ -8,13 +8,12 @@ extension PSQLBackendMessage { let payload: String static func decode(from buffer: inout ByteBuffer) throws -> PSQLBackendMessage.NotificationResponse { - try buffer.psqlEnsureAtLeastNBytesRemaining(6) - let backendPID = buffer.readInteger(as: Int32.self)! + let backendPID = try buffer.throwingReadInteger(as: Int32.self) - guard let channel = buffer.psqlReadNullTerminatedString() else { + guard let channel = buffer.readNullTerminatedString() else { throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } - guard let payload = buffer.psqlReadNullTerminatedString() else { + guard let payload = buffer.readNullTerminatedString() else { throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } diff --git a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift index 49062fda..971b3ac7 100644 --- a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift +++ b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift @@ -7,20 +7,16 @@ extension PSQLBackendMessage { var dataTypes: [PSQLDataType] static func decode(from buffer: inout ByteBuffer) throws -> Self { - try buffer.psqlEnsureAtLeastNBytesRemaining(2) - - let parameterCount = buffer.readInteger(as: Int16.self)! + let parameterCount = try buffer.throwingReadInteger(as: Int16.self) guard parameterCount >= 0 else { throw PSQLPartialDecodingError.integerMustBePositiveOrNull(parameterCount) } - try buffer.psqlEnsureExactNBytesRemaining(Int(parameterCount) * 4) - var result = [PSQLDataType]() result.reserveCapacity(Int(parameterCount)) for _ in 0.. Self { - guard let name = buffer.psqlReadNullTerminatedString() else { + guard let name = buffer.readNullTerminatedString() else { throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } - guard let value = buffer.psqlReadNullTerminatedString() else { + guard let value = buffer.readNullTerminatedString() else { throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } diff --git a/Sources/PostgresNIO/New/Messages/Parse.swift b/Sources/PostgresNIO/New/Messages/Parse.swift index 72eb4962..1d0aec19 100644 --- a/Sources/PostgresNIO/New/Messages/Parse.swift +++ b/Sources/PostgresNIO/New/Messages/Parse.swift @@ -13,8 +13,8 @@ extension PSQLFrontendMessage { let parameters: [PSQLDataType] func encode(into buffer: inout ByteBuffer) { - buffer.psqlWriteNullTerminatedString(self.preparedStatementName) - buffer.psqlWriteNullTerminatedString(self.query) + buffer.writeNullTerminatedString(self.preparedStatementName) + buffer.writeNullTerminatedString(self.query) buffer.writeInteger(Int16(self.parameters.count)) self.parameters.forEach { dataType in diff --git a/Sources/PostgresNIO/New/Messages/Password.swift b/Sources/PostgresNIO/New/Messages/Password.swift index df1bd327..88e885f9 100644 --- a/Sources/PostgresNIO/New/Messages/Password.swift +++ b/Sources/PostgresNIO/New/Messages/Password.swift @@ -6,7 +6,7 @@ extension PSQLFrontendMessage { let value: String func encode(into buffer: inout ByteBuffer) { - buffer.psqlWriteNullTerminatedString(value) + buffer.writeNullTerminatedString(value) } } diff --git a/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift b/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift index 74b30200..b8fff2aa 100644 --- a/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift +++ b/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift @@ -33,10 +33,7 @@ extension PSQLBackendMessage { } static func decode(from buffer: inout ByteBuffer) throws -> Self { - try buffer.psqlEnsureExactNBytesRemaining(1) - - // Exactly one byte is readable. For this reason, we can force unwrap the UInt8 below - let value = buffer.readInteger(as: UInt8.self)! + let value = try buffer.throwingReadInteger(as: UInt8.self) guard let state = Self.init(rawValue: value) else { throw PSQLPartialDecodingError.valueNotRawRepresentable(value: value, asType: TransactionState.self) } diff --git a/Sources/PostgresNIO/New/Messages/RowDescription.swift b/Sources/PostgresNIO/New/Messages/RowDescription.swift index ade0e85c..4f470847 100644 --- a/Sources/PostgresNIO/New/Messages/RowDescription.swift +++ b/Sources/PostgresNIO/New/Messages/RowDescription.swift @@ -37,8 +37,7 @@ struct RowDescription: PSQLBackendMessage.PayloadDecodable, Equatable { } static func decode(from buffer: inout ByteBuffer) throws -> Self { - try buffer.psqlEnsureAtLeastNBytesRemaining(2) - let columnCount = buffer.readInteger(as: Int16.self)! + let columnCount = try buffer.throwingReadInteger(as: Int16.self) guard columnCount >= 0 else { throw PSQLPartialDecodingError.integerMustBePositiveOrNull(columnCount) @@ -48,18 +47,15 @@ struct RowDescription: PSQLBackendMessage.PayloadDecodable, Equatable { result.reserveCapacity(Int(columnCount)) for _ in 0.. 0 { buffer.writeInteger(Int32(self.initialData.count)) diff --git a/Sources/PostgresNIO/New/Messages/Startup.swift b/Sources/PostgresNIO/New/Messages/Startup.swift index 0ceb1050..6e991928 100644 --- a/Sources/PostgresNIO/New/Messages/Startup.swift +++ b/Sources/PostgresNIO/New/Messages/Startup.swift @@ -51,29 +51,26 @@ extension PSQLFrontendMessage { /// Serializes this message into a byte buffer. func encode(into buffer: inout ByteBuffer) { buffer.writeInteger(self.protocolVersion) - buffer.psqlWriteNullTerminatedString("user") - buffer.writeString(self.parameters.user) - buffer.writeInteger(UInt8(0)) + buffer.writeNullTerminatedString("user") + buffer.writeNullTerminatedString(self.parameters.user) if let database = self.parameters.database { - buffer.psqlWriteNullTerminatedString("database") - buffer.writeString(database) - buffer.writeInteger(UInt8(0)) + buffer.writeNullTerminatedString("database") + buffer.writeNullTerminatedString(database) } if let options = self.parameters.options { - buffer.psqlWriteNullTerminatedString("options") - buffer.writeString(options) - buffer.writeInteger(UInt8(0)) + buffer.writeNullTerminatedString("options") + buffer.writeNullTerminatedString(options) } switch self.parameters.replication { case .database: - buffer.psqlWriteNullTerminatedString("replication") - buffer.psqlWriteNullTerminatedString("replication") + buffer.writeNullTerminatedString("replication") + buffer.writeNullTerminatedString("replication") case .true: - buffer.psqlWriteNullTerminatedString("replication") - buffer.psqlWriteNullTerminatedString("true") + buffer.writeNullTerminatedString("replication") + buffer.writeNullTerminatedString("true") case .false: break } diff --git a/Sources/PostgresNIO/New/PSQLBackendMessage.swift b/Sources/PostgresNIO/New/PSQLBackendMessage.swift index c71789f1..77f7b78b 100644 --- a/Sources/PostgresNIO/New/PSQLBackendMessage.swift +++ b/Sources/PostgresNIO/New/PSQLBackendMessage.swift @@ -190,47 +190,58 @@ extension PSQLBackendMessage { switch messageID { case .authentication: return try .authentication(.decode(from: &buffer)) + case .backendKeyData: return try .backendKeyData(.decode(from: &buffer)) + case .bindComplete: - try buffer.psqlEnsureExactNBytesRemaining(0) return .bindComplete + case .closeComplete: - try buffer.psqlEnsureExactNBytesRemaining(0) return .closeComplete + case .commandComplete: - guard let commandTag = buffer.psqlReadNullTerminatedString() else { + guard let commandTag = buffer.readNullTerminatedString() else { throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } return .commandComplete(commandTag) + case .dataRow: return try .dataRow(.decode(from: &buffer)) + case .emptyQueryResponse: - try buffer.psqlEnsureExactNBytesRemaining(0) return .emptyQueryResponse + case .parameterStatus: return try .parameterStatus(.decode(from: &buffer)) + case .error: return try .error(.decode(from: &buffer)) + case .noData: - try buffer.psqlEnsureExactNBytesRemaining(0) return .noData + case .noticeResponse: return try .notice(.decode(from: &buffer)) + case .notificationResponse: return try .notification(.decode(from: &buffer)) + case .parameterDescription: return try .parameterDescription(.decode(from: &buffer)) + case .parseComplete: - try buffer.psqlEnsureExactNBytesRemaining(0) return .parseComplete + case .portalSuspended: - try buffer.psqlEnsureExactNBytesRemaining(0) return .portalSuspended + case .readyForQuery: return try .readyForQuery(.decode(from: &buffer)) + case .rowDescription: return try .rowDescription(.decode(from: &buffer)) + case .copyData, .copyDone, .copyInResponse, .copyOutResponse, .copyBothResponse, .functionCallResponse, .negotiateProtocolVersion: preconditionFailure() } diff --git a/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift b/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift index dd4e4ebf..47485a7b 100644 --- a/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift +++ b/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift @@ -8,68 +8,66 @@ struct PSQLBackendMessageDecoder: NIOSingleStepByteToMessageDecoder { } mutating func decode(buffer: inout ByteBuffer) throws -> PSQLBackendMessage? { - // make sure we have at least one byte to read - guard buffer.readableBytes > 0 else { - return nil - } if !self.hasAlreadyReceivedBytes { // We have not received any bytes yet! Let's peek at the first message id. If it // is a "S" or "N" we assume that it is connected to an SSL upgrade request. All // other messages that we expect now, don't start with either "S" or "N" - // we made sure, we have at least one byte available, above, thus force unwrap is okay - let firstByte = buffer.getInteger(at: buffer.readerIndex, as: UInt8.self)! + let startReaderIndex = buffer.readerIndex + guard let firstByte = buffer.readInteger(as: UInt8.self) else { + return nil + } switch firstByte { case UInt8(ascii: "S"): - // mark byte as read - buffer.moveReaderIndex(forwardBy: 1) self.hasAlreadyReceivedBytes = true return .sslSupported + case UInt8(ascii: "N"): - // mark byte as read - buffer.moveReaderIndex(forwardBy: 1) self.hasAlreadyReceivedBytes = true return .sslUnsupported + default: + // move reader index back + buffer.moveReaderIndex(to: startReaderIndex) self.hasAlreadyReceivedBytes = true } } - // all other packages have an Int32 after the identifier that determines their length. + // all other packages start with a MessageID (UInt8) and their message length (UInt32). // do we have enough bytes for that? - guard buffer.readableBytes >= 5 else { + let startReaderIndex = buffer.readerIndex + guard let (idByte, length) = buffer.readMultipleIntegers(endianness: .big, as: (UInt8, UInt32).self) else { + // if this fails, the readerIndex wasn't changed return nil } - let idByte = buffer.getInteger(at: buffer.readerIndex, as: UInt8.self)! - let length = buffer.getInteger(at: buffer.readerIndex + 1, as: Int32.self)! - - guard length + 1 <= buffer.readableBytes else { + // 1. try to read the message + guard var message = buffer.readSlice(length: Int(length) - 4) else { + // we need to move the reader index back to its start point + buffer.moveReaderIndex(to: startReaderIndex) return nil } - // At this point we are sure, that we have enough bytes to decode the next message. - // 1. Create a byteBuffer that represents exactly the next message. This can be force - // unwrapped, since it was verified that enough bytes are available. - let completeMessageBuffer = buffer.readSlice(length: 1 + Int(length))! - // 2. make sure we have a known message identifier guard let messageID = PSQLBackendMessage.ID(rawValue: idByte) else { - throw PSQLDecodingError.unknownMessageIDReceived(messageID: idByte, messageBytes: completeMessageBuffer) + buffer.moveReaderIndex(to: startReaderIndex) + let completeMessage = buffer.readSlice(length: Int(length) + 1)! + throw PSQLDecodingError.unknownMessageIDReceived(messageID: idByte, messageBytes: completeMessage) } // 3. decode the message do { - // get a mutable byteBuffer copy - var slice = completeMessageBuffer - // move reader index forward by five bytes - slice.moveReaderIndex(forwardBy: 5) - - return try PSQLBackendMessage.decode(from: &slice, for: messageID) + let result = try PSQLBackendMessage.decode(from: &message, for: messageID) + if message.readableBytes > 0 { + throw PSQLPartialDecodingError.expectedExactlyNRemainingBytes(0, actual: message.readableBytes) + } + return result } catch let error as PSQLPartialDecodingError { - throw PSQLDecodingError.withPartialError(error, messageID: messageID.rawValue, messageBytes: completeMessageBuffer) + buffer.moveReaderIndex(to: startReaderIndex) + let completeMessage = buffer.readSlice(length: Int(length) + 1)! + throw PSQLDecodingError.withPartialError(error, messageID: messageID.rawValue, messageBytes: completeMessage) } catch { preconditionFailure("Expected to only see `PartialDecodingError`s here.") } @@ -192,15 +190,16 @@ struct PSQLPartialDecodingError: Error { } extension ByteBuffer { - func psqlEnsureAtLeastNBytesRemaining(_ n: Int, file: String = #file, line: Int = #line) throws { - guard self.readableBytes >= n else { - throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(2, actual: self.readableBytes, file: file, line: line) + mutating func throwingReadInteger(as: I.Type, file: String = #file, line: Int = #line) throws -> I { + guard let result = self.readInteger(endianness: .big, as: I.self) else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(MemoryLayout.size, actual: self.readableBytes, file: file, line: line) } + return result } - - func psqlEnsureExactNBytesRemaining(_ n: Int, file: String = #file, line: Int = #line) throws { - guard self.readableBytes == n else { - throw PSQLPartialDecodingError.expectedExactlyNRemainingBytes(n, actual: self.readableBytes, file: file, line: line) + + mutating func throwingMoveReaderIndex(forwardBy offset: Int, file: String = #file, line: Int = #line) throws { + guard self.readSlice(length: offset) != nil else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(offset, actual: self.readableBytes, file: file, line: line) } } } diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift index 6c1be6f5..8ef8033c 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift @@ -68,7 +68,7 @@ struct PSQLBackendMessageEncoder: MessageToByteEncoder { var string: String init(_ string: String) { self.string = string } func encode(into buffer: inout ByteBuffer) { - buffer.psqlWriteNullTerminatedString(self.string) + buffer.writeNullTerminatedString(self.string) } } @@ -166,7 +166,7 @@ extension PSQLBackendMessage.Authentication: PSQLMessagePayloadEncodable { case .sasl(names: let names): buffer.writeInteger(Int32(10)) for name in names { - buffer.psqlWriteNullTerminatedString(name) + buffer.writeNullTerminatedString(name) } case .saslContinue(data: var data): @@ -199,7 +199,7 @@ extension PSQLBackendMessage.ErrorResponse: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { for (key, value) in self.fields { buffer.writeInteger(key.rawValue, as: UInt8.self) - buffer.psqlWriteNullTerminatedString(value) + buffer.writeNullTerminatedString(value) } buffer.writeInteger(0, as: UInt8.self) // signal done } @@ -209,7 +209,7 @@ extension PSQLBackendMessage.NoticeResponse: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { for (key, value) in self.fields { buffer.writeInteger(key.rawValue, as: UInt8.self) - buffer.psqlWriteNullTerminatedString(value) + buffer.writeNullTerminatedString(value) } buffer.writeInteger(0, as: UInt8.self) // signal done } @@ -218,8 +218,8 @@ extension PSQLBackendMessage.NoticeResponse: PSQLMessagePayloadEncodable { extension PSQLBackendMessage.NotificationResponse: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { buffer.writeInteger(self.backendPID) - buffer.psqlWriteNullTerminatedString(self.channel) - buffer.psqlWriteNullTerminatedString(self.payload) + buffer.writeNullTerminatedString(self.channel) + buffer.writeNullTerminatedString(self.payload) } } @@ -235,8 +235,8 @@ extension PSQLBackendMessage.ParameterDescription: PSQLMessagePayloadEncodable { extension PSQLBackendMessage.ParameterStatus: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { - buffer.psqlWriteNullTerminatedString(self.parameter) - buffer.psqlWriteNullTerminatedString(self.value) + buffer.writeNullTerminatedString(self.parameter) + buffer.writeNullTerminatedString(self.value) } } @@ -251,7 +251,7 @@ extension RowDescription: PSQLMessagePayloadEncodable { buffer.writeInteger(Int16(self.columns.count)) for column in self.columns { - buffer.psqlWriteNullTerminatedString(column.name) + buffer.writeNullTerminatedString(column.name) buffer.writeInteger(column.tableOID) buffer.writeInteger(column.columnAttributeNumber) buffer.writeInteger(column.dataType.rawValue) diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index 4bf988ae..c639f4b2 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -40,8 +40,8 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { var database: String? var options: String? - while let name = messageSlice.psqlReadNullTerminatedString(), messageSlice.readerIndex < finalIndex { - let value = messageSlice.psqlReadNullTerminatedString() + while let name = messageSlice.readNullTerminatedString(), messageSlice.readerIndex < finalIndex { + let value = messageSlice.readNullTerminatedString() switch name { case "user": @@ -136,7 +136,7 @@ extension PSQLFrontendMessage { case .parse: preconditionFailure("TODO: Unimplemented") case .password: - guard let password = buffer.psqlReadNullTerminatedString() else { + guard let password = buffer.readNullTerminatedString() else { throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) } return .password(.init(value: password)) diff --git a/Tests/PostgresNIOTests/New/Messages/BindTests.swift b/Tests/PostgresNIOTests/New/Messages/BindTests.swift index 234e1541..7a688d41 100644 --- a/Tests/PostgresNIOTests/New/Messages/BindTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BindTests.swift @@ -14,8 +14,8 @@ class BindTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, 37) XCTAssertEqual(PSQLFrontendMessage.ID.bind.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 36) - XCTAssertEqual("", byteBuffer.psqlReadNullTerminatedString()) - XCTAssertEqual("", byteBuffer.psqlReadNullTerminatedString()) + XCTAssertEqual("", byteBuffer.readNullTerminatedString()) + XCTAssertEqual("", byteBuffer.readNullTerminatedString()) // the number of parameters XCTAssertEqual(2, byteBuffer.readInteger(as: Int16.self)) // all (two) parameters have the same format (binary) diff --git a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift index 8f8af2bd..4df15896 100644 --- a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift @@ -14,7 +14,7 @@ class CloseTests: XCTestCase { XCTAssertEqual(PSQLFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(11, byteBuffer.readInteger(as: Int32.self)) XCTAssertEqual(UInt8(ascii: "P"), byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual("Hello", byteBuffer.psqlReadNullTerminatedString()) + XCTAssertEqual("Hello", byteBuffer.readNullTerminatedString()) XCTAssertEqual(byteBuffer.readableBytes, 0) } @@ -28,7 +28,7 @@ class CloseTests: XCTestCase { XCTAssertEqual(PSQLFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(6, byteBuffer.readInteger(as: Int32.self)) XCTAssertEqual(UInt8(ascii: "S"), byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual("", byteBuffer.psqlReadNullTerminatedString()) + XCTAssertEqual("", byteBuffer.readNullTerminatedString()) XCTAssertEqual(byteBuffer.readableBytes, 0) } diff --git a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift index fabb0e29..87f7d09b 100644 --- a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift @@ -14,7 +14,7 @@ class DescribeTests: XCTestCase { XCTAssertEqual(PSQLFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(11, byteBuffer.readInteger(as: Int32.self)) XCTAssertEqual(UInt8(ascii: "P"), byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual("Hello", byteBuffer.psqlReadNullTerminatedString()) + XCTAssertEqual("Hello", byteBuffer.readNullTerminatedString()) XCTAssertEqual(byteBuffer.readableBytes, 0) } @@ -28,7 +28,7 @@ class DescribeTests: XCTestCase { XCTAssertEqual(PSQLFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(6, byteBuffer.readInteger(as: Int32.self)) XCTAssertEqual(UInt8(ascii: "S"), byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual("", byteBuffer.psqlReadNullTerminatedString()) + XCTAssertEqual("", byteBuffer.readNullTerminatedString()) XCTAssertEqual(byteBuffer.readableBytes, 0) } diff --git a/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift index df0d63b0..bbc945e4 100644 --- a/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift @@ -19,7 +19,7 @@ class ErrorResponseTests: XCTestCase { let buffer = ByteBuffer.backendMessage(id: .error) { buffer in fields.forEach { (key, value) in buffer.writeInteger(key.rawValue, as: UInt8.self) - buffer.psqlWriteNullTerminatedString(value) + buffer.writeNullTerminatedString(value) } buffer.writeInteger(0, as: UInt8.self) // signal done } diff --git a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift index 0969194c..3ce8d63d 100644 --- a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift @@ -13,7 +13,7 @@ class ExecuteTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, 10) // 1 (id) + 4 (length) + 1 (empty null terminated string) + 4 (count) XCTAssertEqual(PSQLFrontendMessage.ID.execute.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(9, byteBuffer.readInteger(as: Int32.self)) // length - XCTAssertEqual("", byteBuffer.psqlReadNullTerminatedString()) + XCTAssertEqual("", byteBuffer.readNullTerminatedString()) XCTAssertEqual(0, byteBuffer.readInteger(as: Int32.self)) } } diff --git a/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift index abf6b4ed..39fbb220 100644 --- a/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift @@ -20,8 +20,8 @@ class NotificationResponseTests: XCTestCase { buffer.writeBackendMessage(id: .notificationResponse) { buffer in buffer.writeInteger(notification.backendPID) - buffer.psqlWriteNullTerminatedString(notification.channel) - buffer.psqlWriteNullTerminatedString(notification.payload) + buffer.writeNullTerminatedString(notification.channel) + buffer.writeNullTerminatedString(notification.payload) } } @@ -49,7 +49,7 @@ class NotificationResponseTests: XCTestCase { var buffer = ByteBuffer() buffer.writeBackendMessage(id: .notificationResponse) { buffer in buffer.writeInteger(Int32(123)) - buffer.psqlWriteNullTerminatedString("hello") + buffer.writeNullTerminatedString("hello") buffer.writeString("world") } diff --git a/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift b/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift index 2f00fa53..db4963e0 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift @@ -27,8 +27,8 @@ class ParameterStatusTests: XCTestCase { switch message { case .parameterStatus(let parameterStatus): buffer.writeBackendMessage(id: .parameterStatus) { buffer in - buffer.psqlWriteNullTerminatedString(parameterStatus.parameter) - buffer.psqlWriteNullTerminatedString(parameterStatus.value) + buffer.writeNullTerminatedString(parameterStatus.parameter) + buffer.writeNullTerminatedString(parameterStatus.value) } case .backendKeyData(let backendKeyData): buffer.writeBackendMessage(id: .backendKeyData) { buffer in @@ -62,7 +62,7 @@ class ParameterStatusTests: XCTestCase { func testDecodeFailureBecauseOfMissingNullTerminationInValue() { var buffer = ByteBuffer() buffer.writeBackendMessage(id: .parameterStatus) { buffer in - buffer.psqlWriteNullTerminatedString("DateStyle") + buffer.writeNullTerminatedString("DateStyle") buffer.writeString("ISO, MDY") } diff --git a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift index 3393e74d..c147b749 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift @@ -24,8 +24,8 @@ class ParseTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, length) XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.parse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), parse.preparedStatementName) - XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), parse.query) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), parse.preparedStatementName) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), parse.query) XCTAssertEqual(byteBuffer.readInteger(as: Int16.self), Int16(parse.parameters.count)) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), PSQLDataType.bool.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), PSQLDataType.int8.rawValue) diff --git a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift index f7876426..73c464f3 100644 --- a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift @@ -16,6 +16,6 @@ class PasswordTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, expectedLength) XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.password.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(expectedLength - 1)) // length - XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), "md522d085ed8dc3377968dc1c1a40519a2a") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "md522d085ed8dc3377968dc1c1a40519a2a") } } diff --git a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift index ba759dc4..8eba059d 100644 --- a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift @@ -25,7 +25,7 @@ class RowDescriptionTests: XCTestCase { buffer.writeInteger(Int16(description.columns.count)) description.columns.forEach { column in - buffer.psqlWriteNullTerminatedString(column.name) + buffer.writeNullTerminatedString(column.name) buffer.writeInteger(column.tableOID) buffer.writeInteger(column.columnAttributeNumber) buffer.writeInteger(column.dataType.rawValue) @@ -70,7 +70,7 @@ class RowDescriptionTests: XCTestCase { var buffer = ByteBuffer() buffer.writeBackendMessage(id: .rowDescription) { buffer in - buffer.psqlWriteNullTerminatedString(column.name) + buffer.writeNullTerminatedString(column.name) buffer.writeInteger(column.tableOID) buffer.writeInteger(column.columnAttributeNumber) buffer.writeInteger(column.dataType.rawValue) @@ -93,7 +93,7 @@ class RowDescriptionTests: XCTestCase { var buffer = ByteBuffer() buffer.writeBackendMessage(id: .rowDescription) { buffer in buffer.writeInteger(Int16(1)) - buffer.psqlWriteNullTerminatedString(column.name) + buffer.writeNullTerminatedString(column.name) buffer.writeInteger(column.tableOID) buffer.writeInteger(column.columnAttributeNumber) buffer.writeInteger(column.dataType.rawValue) @@ -116,7 +116,7 @@ class RowDescriptionTests: XCTestCase { var buffer = ByteBuffer() buffer.writeBackendMessage(id: .rowDescription) { buffer in buffer.writeInteger(Int16(-1)) - buffer.psqlWriteNullTerminatedString(column.name) + buffer.writeNullTerminatedString(column.name) buffer.writeInteger(column.tableOID) buffer.writeInteger(column.columnAttributeNumber) buffer.writeInteger(column.dataType.rawValue) diff --git a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift index 3c4ae4b3..af2459ac 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift @@ -23,7 +23,7 @@ class SASLInitialResponseTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, length) XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.saslInitialResponse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), sasl.saslMechanism) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), sasl.saslMechanism) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(sasl.initialData.count)) XCTAssertEqual(byteBuffer.readBytes(length: sasl.initialData.count), sasl.initialData) XCTAssertEqual(byteBuffer.readableBytes, 0) @@ -48,7 +48,7 @@ class SASLInitialResponseTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, length) XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.saslInitialResponse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), sasl.saslMechanism) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), sasl.saslMechanism) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(-1)) XCTAssertEqual(byteBuffer.readableBytes, 0) } diff --git a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift index ee63ea1a..1224aede 100644 --- a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift @@ -29,15 +29,15 @@ class StartupTests: XCTestCase { let byteBufferLength = Int32(byteBuffer.readableBytes) XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) XCTAssertEqual(startup.protocolVersion, byteBuffer.readInteger()) - XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), "user") - XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), "test") - XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), "database") - XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), "abc123") - XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), "options") - XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), "some options") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "options") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "some options") if replication != .false { - XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), "replication") - XCTAssertEqual(byteBuffer.psqlReadNullTerminatedString(), replication.stringValue) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "replication") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), replication.stringValue) } XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) diff --git a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift index 0f486180..049e23d1 100644 --- a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift @@ -90,8 +90,8 @@ class PSQLBackendMessageTests: XCTestCase { parameterStatus.forEach { parameterStatus in buffer.writeBackendMessage(id: .parameterStatus) { buffer in - buffer.psqlWriteNullTerminatedString(parameterStatus.parameter) - buffer.psqlWriteNullTerminatedString(parameterStatus.value) + buffer.writeNullTerminatedString(parameterStatus.parameter) + buffer.writeNullTerminatedString(parameterStatus.value) } expectedMessages.append(.parameterStatus(parameterStatus)) @@ -132,7 +132,7 @@ class PSQLBackendMessageTests: XCTestCase { buffer.writeBackendMessage(id: .noticeResponse) { buffer in fields.forEach { (key, value) in buffer.writeInteger(key.rawValue, as: UInt8.self) - buffer.psqlWriteNullTerminatedString(value) + buffer.writeNullTerminatedString(value) } buffer.writeInteger(0, as: UInt8.self) // signal done } @@ -216,7 +216,7 @@ class PSQLBackendMessageTests: XCTestCase { } okBuffer.writeBackendMessage(id: .commandComplete) { buffer in - buffer.psqlWriteNullTerminatedString(commandTag) + buffer.writeNullTerminatedString(commandTag) } } From 780a510863bfc00b2239b649df4f403080c4bc9c Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 27 Nov 2021 00:30:22 +0100 Subject: [PATCH 042/292] Refactor PSQLRowStream to make async/await easier (#201) ### Motivation `PSQLRowStream`'s current implementation is interesting. It should be better tested and easier to follow for async/await support later. ### Changes - Make `PSQLRowStream`'s implementation more sensible - Add unit tests for `PSQLRowStream` ### Result Adding async/await support becomes easier. --- Sources/PostgresNIO/New/PSQLRow.swift | 6 + Sources/PostgresNIO/New/PSQLRowStream.swift | 204 ++++++----- .../New/PSQLRowStreamTests.swift | 323 ++++++++++++++++++ 3 files changed, 429 insertions(+), 104 deletions(-) create mode 100644 Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift diff --git a/Sources/PostgresNIO/New/PSQLRow.swift b/Sources/PostgresNIO/New/PSQLRow.swift index e7a6ed7e..99115d73 100644 --- a/Sources/PostgresNIO/New/PSQLRow.swift +++ b/Sources/PostgresNIO/New/PSQLRow.swift @@ -16,6 +16,12 @@ struct PSQLRow { } } +extension PSQLRow: Equatable { + static func ==(lhs: Self, rhs: Self) -> Bool { + lhs.data == rhs.data && lhs.columns == rhs.columns + } +} + extension PSQLRow { /// Access the data in the provided column and decode it into the target type. /// diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index e3d74f16..54bc74fd 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -2,7 +2,6 @@ import NIOCore import Logging final class PSQLRowStream { - enum RowSource { case stream(PSQLRowsDataSource) case noRows(Result) @@ -11,23 +10,21 @@ final class PSQLRowStream { let eventLoop: EventLoop let logger: Logger - private enum UpstreamState { + private enum BufferState { case streaming(buffer: CircularBuffer, dataSource: PSQLRowsDataSource) case finished(buffer: CircularBuffer, commandTag: String) case failure(Error) - case consumed(Result) - case modifying } private enum DownstreamState { - case iteratingRows(onRow: (PSQLRow) throws -> (), EventLoopPromise) - case waitingForAll(EventLoopPromise<[PSQLRow]>) - case consuming + case waitingForConsumer(BufferState) + case iteratingRows(onRow: (PSQLRow) throws -> (), EventLoopPromise, PSQLRowsDataSource) + case waitingForAll([PSQLRow], EventLoopPromise<[PSQLRow]>, PSQLRowsDataSource) + case consumed(Result) } internal let rowDescription: [RowDescription.Column] private let lookupTable: [String: Int] - private var upstreamState: UpstreamState private var downstreamState: DownstreamState private let jsonDecoder: PSQLJSONDecoder @@ -36,23 +33,24 @@ final class PSQLRowStream { eventLoop: EventLoop, rowSource: RowSource) { - let buffer = CircularBuffer() - - self.downstreamState = .consuming + let bufferState: BufferState switch rowSource { case .stream(let dataSource): - self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource) + bufferState = .streaming(buffer: .init(), dataSource: dataSource) case .noRows(.success(let commandTag)): - self.upstreamState = .finished(buffer: .init(), commandTag: commandTag) + bufferState = .finished(buffer: .init(), commandTag: commandTag) case .noRows(.failure(let error)): - self.upstreamState = .failure(error) + bufferState = .failure(error) } + self.downstreamState = .waitingForConsumer(bufferState) + self.eventLoop = eventLoop self.logger = queryContext.logger self.jsonDecoder = queryContext.jsonDecoder self.rowDescription = rowDescription + var lookup = [String: Int]() lookup.reserveCapacity(rowDescription.count) rowDescription.enumerated().forEach { (index, column) in @@ -60,6 +58,8 @@ final class PSQLRowStream { } self.lookupTable = lookup } + + // MARK: Consume in array func all() -> EventLoopFuture<[PSQLRow]> { if self.eventLoop.inEventLoop { @@ -74,40 +74,37 @@ final class PSQLRowStream { private func all0() -> EventLoopFuture<[PSQLRow]> { self.eventLoop.preconditionInEventLoop() - guard case .consuming = self.downstreamState else { - preconditionFailure("Invalid state") + guard case .waitingForConsumer(let bufferState) = self.downstreamState else { + preconditionFailure("Invalid state: \(self.downstreamState)") } - switch self.upstreamState { - case .streaming(_, let dataSource): - dataSource.request(for: self) + switch bufferState { + case .streaming(let bufferedRows, let dataSource): let promise = self.eventLoop.makePromise(of: [PSQLRow].self) - self.downstreamState = .waitingForAll(promise) + let rows = bufferedRows.map { data in + PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) + } + self.downstreamState = .waitingForAll(rows, promise, dataSource) + // immediately request more + dataSource.request(for: self) return promise.futureResult case .finished(let buffer, let commandTag): - self.upstreamState = .modifying - let rows = buffer.map { PSQLRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) } - self.downstreamState = .consuming - self.upstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(commandTag)) return self.eventLoop.makeSucceededFuture(rows) - case .consumed: - preconditionFailure("We already signaled, that the stream has completed, why are we asked again?") - - case .modifying: - preconditionFailure("Invalid state") - case .failure(let error): - self.upstreamState = .consumed(.failure(error)) + self.downstreamState = .consumed(.failure(error)) return self.eventLoop.makeFailedFuture(error) } } + // MARK: Consume on EventLoop + func onRow(_ onRow: @escaping (PSQLRow) throws -> ()) -> EventLoopFuture { if self.eventLoop.inEventLoop { return self.onRow0(onRow) @@ -121,7 +118,11 @@ final class PSQLRowStream { private func onRow0(_ onRow: @escaping (PSQLRow) throws -> ()) -> EventLoopFuture { self.eventLoop.preconditionInEventLoop() - switch self.upstreamState { + guard case .waitingForConsumer(let bufferState) = self.downstreamState else { + preconditionFailure("Invalid state: \(self.downstreamState)") + } + + switch bufferState { case .streaming(var buffer, let dataSource): let promise = self.eventLoop.makePromise(of: Void.self) do { @@ -136,12 +137,11 @@ final class PSQLRowStream { } buffer.removeAll() - self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource) - self.downstreamState = .iteratingRows(onRow: onRow, promise) + self.downstreamState = .iteratingRows(onRow: onRow, promise, dataSource) // immediately request more dataSource.request(for: self) } catch { - self.upstreamState = .failure(error) + self.downstreamState = .consumed(.failure(error)) dataSource.cancel(for: self) promise.fail(error) } @@ -160,22 +160,15 @@ final class PSQLRowStream { try onRow(row) } - self.upstreamState = .consumed(.success(commandTag)) - self.downstreamState = .consuming + self.downstreamState = .consumed(.success(commandTag)) return self.eventLoop.makeSucceededVoidFuture() } catch { - self.upstreamState = .consumed(.failure(error)) + self.downstreamState = .consumed(.failure(error)) return self.eventLoop.makeFailedFuture(error) } - case .consumed: - preconditionFailure("We already signaled, that the stream has completed, why are we asked again?") - - case .modifying: - preconditionFailure("Invalid state") - case .failure(let error): - self.upstreamState = .consumed(.failure(error)) + self.downstreamState = .consumed(.failure(error)) return self.eventLoop.makeFailedFuture(error) } } @@ -193,13 +186,15 @@ final class PSQLRowStream { "row_count": "\(newRows.count)" ]) - guard case .streaming(var buffer, let dataSource) = self.upstreamState else { - preconditionFailure("Invalid state") - } - switch self.downstreamState { - case .iteratingRows(let onRow, let promise): - precondition(buffer.isEmpty) + case .waitingForConsumer(.streaming(buffer: var buffer, dataSource: let dataSource)): + buffer.append(contentsOf: newRows) + self.downstreamState = .waitingForConsumer(.streaming(buffer: buffer, dataSource: dataSource)) + + case .waitingForConsumer(.finished), .waitingForConsumer(.failure): + preconditionFailure("How can new rows be received, if an end was already signalled?") + + case .iteratingRows(let onRow, let promise, let dataSource): do { for data in newRows { let row = PSQLRow( @@ -214,82 +209,83 @@ final class PSQLRowStream { dataSource.request(for: self) } catch { dataSource.cancel(for: self) - self.upstreamState = .failure(error) + self.downstreamState = .consumed(.failure(error)) promise.fail(error) return } - case .waitingForAll: - self.upstreamState = .modifying - buffer.append(contentsOf: newRows) - self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource) - + + case .waitingForAll(var rows, let promise, let dataSource): + newRows.forEach { data in + let row = PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) + rows.append(row) + } + self.downstreamState = .waitingForAll(rows, promise, dataSource) // immediately request more dataSource.request(for: self) - case .consuming: - // this might happen, if the query has finished while the user is consuming data - // we don't need to ask for more since the user is consuming anyway - self.upstreamState = .modifying - buffer.append(contentsOf: newRows) - self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource) + case .consumed(.success): + preconditionFailure("How can we receive further rows, if we are supposed to be done") + + case .consumed(.failure): + break } } internal func receive(completion result: Result) { self.eventLoop.preconditionInEventLoop() - guard case .streaming(let oldBuffer, _) = self.upstreamState else { - preconditionFailure("Invalid state") + switch result { + case .success(let commandTag): + self.receiveEnd(commandTag) + case .failure(let error): + self.receiveError(error) } + } + private func receiveEnd(_ commandTag: String) { switch self.downstreamState { - case .iteratingRows(_, let promise): - precondition(oldBuffer.isEmpty) - self.downstreamState = .consuming - self.upstreamState = .consumed(result) - switch result { - case .success: - promise.succeed(()) - case .failure(let error): - promise.fail(error) - } + case .waitingForConsumer(.streaming(buffer: let buffer, _)): + self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, commandTag: commandTag)) + case .waitingForConsumer(.finished), .waitingForConsumer(.failure): + preconditionFailure("How can we get another end, if an end was already signalled?") - case .consuming: - switch result { - case .success(let commandTag): - self.upstreamState = .finished(buffer: oldBuffer, commandTag: commandTag) - case .failure(let error): - self.upstreamState = .failure(error) - } - - case .waitingForAll(let promise): - switch result { - case .failure(let error): - self.upstreamState = .consumed(.failure(error)) - promise.fail(error) - case .success(let commandTag): - let rows = oldBuffer.map { - PSQLRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) - } - self.upstreamState = .consumed(.success(commandTag)) - promise.succeed(rows) - } + case .iteratingRows(_, let promise, _): + self.downstreamState = .consumed(.success(commandTag)) + promise.succeed(()) + + case .waitingForAll(let rows, let promise, _): + self.downstreamState = .consumed(.success(commandTag)) + promise.succeed(rows) + + case .consumed: + break } } - - func cancel() { - guard case .streaming(_, let dataSource) = self.upstreamState else { - // We don't need to cancel any upstream resource. All needed data is already - // included in this - return - } - dataSource.cancel(for: self) + private func receiveError(_ error: Error) { + switch self.downstreamState { + case .waitingForConsumer(.streaming): + self.downstreamState = .waitingForConsumer(.failure(error)) + + case .waitingForConsumer(.finished), .waitingForConsumer(.failure): + preconditionFailure("How can we get another end, if an end was already signalled?") + + case .iteratingRows(_, let promise, _): + self.downstreamState = .consumed(.failure(error)) + promise.fail(error) + + case .waitingForAll(_, let promise, _): + self.downstreamState = .consumed(.failure(error)) + promise.fail(error) + + case .consumed: + break + } } var commandTag: String { - guard case .consumed(.success(let commandTag)) = self.upstreamState else { + guard case .consumed(.success(let commandTag)) = self.downstreamState else { preconditionFailure("commandTag may only be called if all rows have been consumed") } return commandTag diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift new file mode 100644 index 00000000..658f123f --- /dev/null +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -0,0 +1,323 @@ +import NIOCore +import Logging +import XCTest +@testable import PostgresNIO + +class PSQLRowStreamTests: XCTestCase { + func testEmptyStream() { + let logger = Logger(label: "test") + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + + let queryContext = ExtendedQueryContext( + query: "INSERT INTO foo bar;", bind: [], logger: logger, jsonDecoder: JSONDecoder(), promise: promise + ) + + let stream = PSQLRowStream( + rowDescription: [], + queryContext: queryContext, + eventLoop: eventLoop, + rowSource: .noRows(.success("INSERT 0 1")) + ) + promise.succeed(stream) + + XCTAssertEqual(try stream.all().wait(), []) + XCTAssertEqual(stream.commandTag, "INSERT 0 1") + } + + func testFailedStream() { + let logger = Logger(label: "test") + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + + let queryContext = ExtendedQueryContext( + query: "SELECT * FROM test;", bind: [], logger: logger, jsonDecoder: JSONDecoder(), promise: promise + ) + + let stream = PSQLRowStream( + rowDescription: [], + queryContext: queryContext, + eventLoop: eventLoop, + rowSource: .noRows(.failure(PSQLError.connectionClosed)) + ) + promise.succeed(stream) + + XCTAssertThrowsError(try stream.all().wait()) { + XCTAssertEqual($0 as? PSQLError, .connectionClosed) + } + } + + func testGetArrayAfterStreamHasFinished() { + let logger = Logger(label: "test") + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + + let queryContext = ExtendedQueryContext( + query: "SELECT * FROM test;", bind: [], logger: logger, jsonDecoder: JSONDecoder(), promise: promise + ) + + let dataSource = CountingDataSource() + let stream = PSQLRowStream( + rowDescription: [ + self.makeColumnDescription(name: "foo", dataType: .text, format: .binary) + ], + queryContext: queryContext, + eventLoop: eventLoop, + rowSource: .stream(dataSource) + ) + promise.succeed(stream) + XCTAssertEqual(dataSource.hitDemand, 0) + XCTAssertEqual(dataSource.hitCancel, 0) + + stream.receive([ + [ByteBuffer(string: "0")], + [ByteBuffer(string: "1")] + ]) + + XCTAssertEqual(dataSource.hitDemand, 0, "Before we have a consumer demand is not signaled") + stream.receive(completion: .success("SELECT 2")) + + // attach consumer + let future = stream.all() + XCTAssertEqual(dataSource.hitDemand, 0) // TODO: Is this right? + + var rows: [PSQLRow]? + XCTAssertNoThrow(rows = try future.wait()) + XCTAssertEqual(rows?.count, 2) + } + + func testGetArrayBeforeStreamHasFinished() { + let logger = Logger(label: "test") + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + + let queryContext = ExtendedQueryContext( + query: "SELECT * FROM test;", bind: [], logger: logger, jsonDecoder: JSONDecoder(), promise: promise + ) + + let dataSource = CountingDataSource() + let stream = PSQLRowStream( + rowDescription: [ + self.makeColumnDescription(name: "foo", dataType: .text, format: .binary) + ], + queryContext: queryContext, + eventLoop: eventLoop, + rowSource: .stream(dataSource) + ) + promise.succeed(stream) + XCTAssertEqual(dataSource.hitDemand, 0) + XCTAssertEqual(dataSource.hitCancel, 0) + + stream.receive([ + [ByteBuffer(string: "0")], + [ByteBuffer(string: "1")] + ]) + + XCTAssertEqual(dataSource.hitDemand, 0, "Before we have a consumer demand is not signaled") + + // attach consumer + let future = stream.all() + XCTAssertEqual(dataSource.hitDemand, 1) + + stream.receive([ + [ByteBuffer(string: "2")], + [ByteBuffer(string: "3")] + ]) + XCTAssertEqual(dataSource.hitDemand, 2) + + stream.receive([ + [ByteBuffer(string: "4")], + [ByteBuffer(string: "5")] + ]) + XCTAssertEqual(dataSource.hitDemand, 3) + + stream.receive(completion: .success("SELECT 2")) + + var rows: [PSQLRow]? + XCTAssertNoThrow(rows = try future.wait()) + XCTAssertEqual(rows?.count, 6) + } + + func testOnRowAfterStreamHasFinished() { + let logger = Logger(label: "test") + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + + let queryContext = ExtendedQueryContext( + query: "SELECT * FROM test;", bind: [], logger: logger, jsonDecoder: JSONDecoder(), promise: promise + ) + + let dataSource = CountingDataSource() + let stream = PSQLRowStream( + rowDescription: [ + self.makeColumnDescription(name: "foo", dataType: .text, format: .binary) + ], + queryContext: queryContext, + eventLoop: eventLoop, + rowSource: .stream(dataSource) + ) + promise.succeed(stream) + XCTAssertEqual(dataSource.hitDemand, 0) + XCTAssertEqual(dataSource.hitCancel, 0) + + stream.receive([ + [ByteBuffer(string: "0")], + [ByteBuffer(string: "1")] + ]) + + stream.receive(completion: .success("SELECT 2")) + + XCTAssertEqual(dataSource.hitDemand, 0) + + // attach consumer + var counter = 0 + let future = stream.onRow { row in + XCTAssertEqual(try row.decode(column: 0, as: String.self), "\(counter)") + counter += 1 + } + XCTAssertEqual(counter, 2) + XCTAssertEqual(dataSource.hitDemand, 0) + + XCTAssertNoThrow(try future.wait()) + XCTAssertEqual(stream.commandTag, "SELECT 2") + } + + func testOnRowThrowsErrorOnInitialBatch() { + let logger = Logger(label: "test") + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + + let queryContext = ExtendedQueryContext( + query: "SELECT * FROM test;", bind: [], logger: logger, jsonDecoder: JSONDecoder(), promise: promise + ) + + let dataSource = CountingDataSource() + let stream = PSQLRowStream( + rowDescription: [ + self.makeColumnDescription(name: "foo", dataType: .text, format: .binary) + ], + queryContext: queryContext, + eventLoop: eventLoop, + rowSource: .stream(dataSource) + ) + promise.succeed(stream) + XCTAssertEqual(dataSource.hitDemand, 0) + XCTAssertEqual(dataSource.hitCancel, 0) + + stream.receive([ + [ByteBuffer(string: "0")], + [ByteBuffer(string: "1")] + ]) + + stream.receive(completion: .success("SELECT 2")) + + XCTAssertEqual(dataSource.hitDemand, 0) + + // attach consumer + var counter = 0 + let future = stream.onRow { row in + XCTAssertEqual(try row.decode(column: 0, as: String.self), "\(counter)") + if counter == 1 { + throw OnRowError(row: counter) + } + counter += 1 + } + XCTAssertEqual(counter, 1) + XCTAssertEqual(dataSource.hitDemand, 0) + + XCTAssertThrowsError(try future.wait()) { + XCTAssertEqual($0 as? OnRowError, OnRowError(row: 1)) + } + } + + + func testOnRowBeforeStreamHasFinished() { + let logger = Logger(label: "test") + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + + let queryContext = ExtendedQueryContext( + query: "SELECT * FROM test;", bind: [], logger: logger, jsonDecoder: JSONDecoder(), promise: promise + ) + + let dataSource = CountingDataSource() + let stream = PSQLRowStream( + rowDescription: [ + self.makeColumnDescription(name: "foo", dataType: .text, format: .binary) + ], + queryContext: queryContext, + eventLoop: eventLoop, + rowSource: .stream(dataSource) + ) + promise.succeed(stream) + XCTAssertEqual(dataSource.hitDemand, 0) + XCTAssertEqual(dataSource.hitCancel, 0) + + stream.receive([ + [ByteBuffer(string: "0")], + [ByteBuffer(string: "1")] + ]) + + XCTAssertEqual(dataSource.hitDemand, 0, "Before we have a consumer demand is not signaled") + + // attach consumer + var counter = 0 + let future = stream.onRow { row in + XCTAssertEqual(try row.decode(column: 0, as: String.self), "\(counter)") + counter += 1 + } + XCTAssertEqual(counter, 2) + XCTAssertEqual(dataSource.hitDemand, 1) + + stream.receive([ + [ByteBuffer(string: "2")], + [ByteBuffer(string: "3")] + ]) + XCTAssertEqual(counter, 4) + XCTAssertEqual(dataSource.hitDemand, 2) + + stream.receive([ + [ByteBuffer(string: "4")], + [ByteBuffer(string: "5")] + ]) + XCTAssertEqual(counter, 6) + XCTAssertEqual(dataSource.hitDemand, 3) + + stream.receive(completion: .success("SELECT 6")) + + XCTAssertNoThrow(try future.wait()) + XCTAssertEqual(stream.commandTag, "SELECT 6") + } + + func makeColumnDescription(name: String, dataType: PSQLDataType, format: PSQLFormat) -> RowDescription.Column { + RowDescription.Column( + name: "test", + tableOID: 123, + columnAttributeNumber: 1, + dataType: .text, + dataTypeSize: -1, + dataTypeModifier: 0, + format: .binary + ) + } +} + +private struct OnRowError: Error, Equatable { + var row: Int +} + +class CountingDataSource: PSQLRowsDataSource { + + var hitDemand: Int = 0 + var hitCancel: Int = 0 + + init() {} + + func cancel(for stream: PSQLRowStream) { + self.hitCancel += 1 + } + + func request(for stream: PSQLRowStream) { + self.hitDemand += 1 + } +} From 0157e1dfb6bebc5404e8608066e2e9ac0c65a46c Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Wed, 15 Dec 2021 06:47:46 -0600 Subject: [PATCH 043/292] Add NOTICE.txt --- NOTICE.txt | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 NOTICE.txt diff --git a/NOTICE.txt b/NOTICE.txt new file mode 100644 index 00000000..9547a780 --- /dev/null +++ b/NOTICE.txt @@ -0,0 +1,13 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Vapor open source project +// +// Copyright (c) 2017-2021 Vapor project authors +// Licensed under MIT +// +// See LICENSE for license information +// +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// + From e231a57fbf8e8dced86d7a80bc02735948635043 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20L=C3=A9veill=C3=A9?= Date: Wed, 19 Jan 2022 11:05:35 -0500 Subject: [PATCH 044/292] Fix typo in README.md (#208) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c99ae224..79dfc669 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,7 @@ PostgresNIO supports the following platforms: - Ubuntu 16.04+ - macOS 10.15+ -### Secrurity +### Security Please see [SECURITY.md](https://github.com/vapor/.github/blob/main/SECURITY.md) for details on the security process. From cc07811a28b7cb1bd5beabeb3de335674bc465e0 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 17 Feb 2022 17:48:04 +0100 Subject: [PATCH 045/292] Remove PSQLJSONDecoder from PSQLConnection (#214) --- .../PostgresConnection+Connect.swift | 3 +-- Sources/PostgresNIO/New/PSQLConnection.swift | 14 +++-------- Sources/PostgresNIO/New/PSQLRow.swift | 25 +++++++++++++------ Sources/PostgresNIO/New/PSQLRowStream.swift | 17 +++++-------- Sources/PostgresNIO/New/PSQLTask.swift | 5 ---- .../ConnectionStateMachineTests.swift | 2 -- .../ExtendedQueryStateMachineTests.swift | 6 ++--- .../New/PSQLRowStreamTests.swift | 16 ++++++------ 8 files changed, 39 insertions(+), 49 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift index 49463aa5..518e9234 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift @@ -12,8 +12,7 @@ extension PostgresConnection { ) -> EventLoopFuture { let coders = PSQLConnection.Configuration.Coders( - jsonEncoder: PostgresJSONEncoderWrapper(_defaultJSONEncoder), - jsonDecoder: PostgresJSONDecoderWrapper(_defaultJSONDecoder) + jsonEncoder: PostgresJSONEncoderWrapper(_defaultJSONEncoder) ) let configuration = PSQLConnection.Configuration( diff --git a/Sources/PostgresNIO/New/PSQLConnection.swift b/Sources/PostgresNIO/New/PSQLConnection.swift index 3d1d5f37..4f5d3f64 100644 --- a/Sources/PostgresNIO/New/PSQLConnection.swift +++ b/Sources/PostgresNIO/New/PSQLConnection.swift @@ -14,15 +14,13 @@ final class PSQLConnection { struct Coders { var jsonEncoder: PSQLJSONEncoder - var jsonDecoder: PSQLJSONDecoder - init(jsonEncoder: PSQLJSONEncoder, jsonDecoder: PSQLJSONDecoder) { + init(jsonEncoder: PSQLJSONEncoder) { self.jsonEncoder = jsonEncoder - self.jsonDecoder = jsonDecoder } static var foundation: Coders { - Coders(jsonEncoder: JSONEncoder(), jsonDecoder: JSONDecoder()) + Coders(jsonEncoder: JSONEncoder()) } } @@ -98,13 +96,11 @@ final class PSQLConnection { /// A logger to use in case private var logger: Logger let connectionID: String - let jsonDecoder: PSQLJSONDecoder - init(channel: Channel, connectionID: String, logger: Logger, jsonDecoder: PSQLJSONDecoder) { + init(channel: Channel, connectionID: String, logger: Logger) { self.channel = channel self.connectionID = connectionID self.logger = logger - self.jsonDecoder = jsonDecoder } deinit { assert(self.isClosed, "PostgresConnection deinitialized before being closed.") @@ -136,7 +132,6 @@ final class PSQLConnection { query: query, bind: bind, logger: logger, - jsonDecoder: self.jsonDecoder, promise: promise) self.channel.write(PSQLTask.extendedQuery(context), promise: nil) @@ -171,7 +166,6 @@ final class PSQLConnection { preparedStatement: preparedStatement, bind: bind, logger: logger, - jsonDecoder: self.jsonDecoder, promise: promise) self.channel.write(PSQLTask.extendedQuery(context), promise: nil) @@ -258,7 +252,7 @@ final class PSQLConnection { } }.map { _ in channel } }.map { channel in - PSQLConnection(channel: channel, connectionID: connectionID, logger: logger, jsonDecoder: configuration.coders.jsonDecoder) + PSQLConnection(channel: channel, connectionID: connectionID, logger: logger) }.flatMapErrorThrowing { error -> PSQLConnection in switch error { case is PSQLError: diff --git a/Sources/PostgresNIO/New/PSQLRow.swift b/Sources/PostgresNIO/New/PSQLRow.swift index 99115d73..9fbc7f14 100644 --- a/Sources/PostgresNIO/New/PSQLRow.swift +++ b/Sources/PostgresNIO/New/PSQLRow.swift @@ -1,4 +1,5 @@ import NIOCore +import Foundation /// `PSQLRow` represents a single row that was received from the Postgres Server. struct PSQLRow { @@ -6,13 +7,11 @@ struct PSQLRow { internal let data: DataRow internal let columns: [RowDescription.Column] - internal let jsonDecoder: PSQLJSONDecoder - internal init(data: DataRow, lookupTable: [String: Int], columns: [RowDescription.Column], jsonDecoder: PSQLJSONDecoder) { + internal init(data: DataRow, lookupTable: [String: Int], columns: [RowDescription.Column]) { self.data = data self.lookupTable = lookupTable self.columns = columns - self.jsonDecoder = jsonDecoder } } @@ -30,12 +29,12 @@ extension PSQLRow { /// - type: The type to decode the data into /// - Throws: The error of the decoding implementation. See also `PSQLDecodable` protocol for this. /// - Returns: The decoded value of Type T. - func decode(column: String, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { + func decode(column: String, as type: T.Type, jsonDecoder: JSONDecoder, file: String = #file, line: Int = #line) throws -> T { guard let index = self.lookupTable[column] else { preconditionFailure("A column '\(column)' does not exist.") } - return try self.decode(column: index, as: type, file: file, line: line) + return try self.decode(column: index, as: type, jsonDecoder: jsonDecoder, file: file, line: line) } /// Access the data in the provided column and decode it into the target type. @@ -45,12 +44,12 @@ extension PSQLRow { /// - type: The type to decode the data into /// - Throws: The error of the decoding implementation. See also `PSQLDecodable` protocol for this. /// - Returns: The decoded value of Type T. - func decode(column index: Int, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { + func decode(column index: Int, as type: T.Type, jsonDecoder: JSONDecoder, file: String = #file, line: Int = #line) throws -> T { precondition(index < self.data.columnCount) let column = self.columns[index] let context = PSQLDecodingContext( - jsonDecoder: self.jsonDecoder, + jsonDecoder: jsonDecoder, columnName: column.name, columnIndex: index, file: file, @@ -63,3 +62,15 @@ extension PSQLRow { return try T.decode(from: &cellSlice, type: column.dataType, format: column.format, context: context) } } + +extension PSQLRow { + // TODO: Remove this function. Only here to keep the tests running as of today. + func decode(column: String, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { + try self.decode(column: column, as: type, jsonDecoder: JSONDecoder(), file: file, line: line) + } + + // TODO: Remove this function. Only here to keep the tests running as of today. + func decode(column index: Int, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { + try self.decode(column: index, as: type, jsonDecoder: JSONDecoder(), file: file, line: line) + } +} diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index 54bc74fd..3262e995 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -26,7 +26,6 @@ final class PSQLRowStream { internal let rowDescription: [RowDescription.Column] private let lookupTable: [String: Int] private var downstreamState: DownstreamState - private let jsonDecoder: PSQLJSONDecoder init(rowDescription: [RowDescription.Column], queryContext: ExtendedQueryContext, @@ -47,7 +46,6 @@ final class PSQLRowStream { self.eventLoop = eventLoop self.logger = queryContext.logger - self.jsonDecoder = queryContext.jsonDecoder self.rowDescription = rowDescription @@ -82,7 +80,7 @@ final class PSQLRowStream { case .streaming(let bufferedRows, let dataSource): let promise = self.eventLoop.makePromise(of: [PSQLRow].self) let rows = bufferedRows.map { data in - PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) + PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription) } self.downstreamState = .waitingForAll(rows, promise, dataSource) // immediately request more @@ -91,7 +89,7 @@ final class PSQLRowStream { case .finished(let buffer, let commandTag): let rows = buffer.map { - PSQLRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) + PSQLRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription) } self.downstreamState = .consumed(.success(commandTag)) @@ -130,8 +128,7 @@ final class PSQLRowStream { let row = PSQLRow( data: data, lookupTable: self.lookupTable, - columns: self.rowDescription, - jsonDecoder: self.jsonDecoder + columns: self.rowDescription ) try onRow(row) } @@ -154,8 +151,7 @@ final class PSQLRowStream { let row = PSQLRow( data: data, lookupTable: self.lookupTable, - columns: self.rowDescription, - jsonDecoder: self.jsonDecoder + columns: self.rowDescription ) try onRow(row) } @@ -200,8 +196,7 @@ final class PSQLRowStream { let row = PSQLRow( data: data, lookupTable: self.lookupTable, - columns: self.rowDescription, - jsonDecoder: self.jsonDecoder + columns: self.rowDescription ) try onRow(row) } @@ -216,7 +211,7 @@ final class PSQLRowStream { case .waitingForAll(var rows, let promise, let dataSource): newRows.forEach { data in - let row = PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) + let row = PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription) rows.append(row) } self.downstreamState = .waitingForAll(rows, promise, dataSource) diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 1f7a06d6..0f0c6d04 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -28,26 +28,22 @@ final class ExtendedQueryContext { let bind: [PSQLEncodable] let logger: Logger - let jsonDecoder: PSQLJSONDecoder let promise: EventLoopPromise init(query: String, bind: [PSQLEncodable], logger: Logger, - jsonDecoder: PSQLJSONDecoder, promise: EventLoopPromise) { self.query = .unnamed(query) self.bind = bind self.logger = logger - self.jsonDecoder = jsonDecoder self.promise = promise } init(preparedStatement: PSQLPreparedStatement, bind: [PSQLEncodable], logger: Logger, - jsonDecoder: PSQLJSONDecoder, promise: EventLoopPromise) { self.query = .preparedStatement( @@ -55,7 +51,6 @@ final class ExtendedQueryContext { rowDescription: preparedStatement.rowDescription) self.bind = bind self.logger = logger - self.jsonDecoder = jsonDecoder self.promise = promise } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index e796c0f9..79dc27c4 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -124,7 +124,6 @@ class ConnectionStateMachineTests: XCTestCase { let authContext = AuthContext(username: "test", password: "abc123", database: "test") let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) - let jsonDecoder = JSONDecoder() let queryPromise = eventLoopGroup.next().makePromise(of: PSQLRowStream.self) var state = ConnectionStateMachine() @@ -132,7 +131,6 @@ class ConnectionStateMachineTests: XCTestCase { query: "Select version()", bind: [], logger: .psqlTest, - jsonDecoder: jsonDecoder, promise: queryPromise) XCTAssertEqual(state.enqueue(task: .extendedQuery(extendedQueryContext)), .wait) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index 39360645..e3a3e515 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -13,7 +13,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. let query = "DELETE FROM table WHERE id=$0" - let queryContext = ExtendedQueryContext(query: query, bind: [1], logger: logger, jsonDecoder: JSONDecoder(), promise: promise) + let queryContext = ExtendedQueryContext(query: query, bind: [1], logger: logger, promise: promise) XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query: query, binds: [1])) XCTAssertEqual(state.parseCompleteReceived(), .wait) @@ -31,7 +31,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let queryPromise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) queryPromise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. let query = "SELECT version()" - let queryContext = ExtendedQueryContext(query: query, bind: [], logger: logger, jsonDecoder: JSONDecoder(), promise: queryPromise) + let queryContext = ExtendedQueryContext(query: query, bind: [], logger: logger, promise: queryPromise) XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query: query, binds: [])) XCTAssertEqual(state.parseCompleteReceived(), .wait) @@ -85,7 +85,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. let query = "DELETE FROM table WHERE id=$0" - let queryContext = ExtendedQueryContext(query: query, bind: [1], logger: logger, jsonDecoder: JSONDecoder(), promise: promise) + let queryContext = ExtendedQueryContext(query: query, bind: [1], logger: logger, promise: promise) XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query: query, binds: [1])) XCTAssertEqual(state.parseCompleteReceived(), .wait) diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift index 658f123f..ea303f20 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -10,7 +10,7 @@ class PSQLRowStreamTests: XCTestCase { let promise = eventLoop.makePromise(of: PSQLRowStream.self) let queryContext = ExtendedQueryContext( - query: "INSERT INTO foo bar;", bind: [], logger: logger, jsonDecoder: JSONDecoder(), promise: promise + query: "INSERT INTO foo bar;", bind: [], logger: logger, promise: promise ) let stream = PSQLRowStream( @@ -31,7 +31,7 @@ class PSQLRowStreamTests: XCTestCase { let promise = eventLoop.makePromise(of: PSQLRowStream.self) let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", bind: [], logger: logger, jsonDecoder: JSONDecoder(), promise: promise + query: "SELECT * FROM test;", bind: [], logger: logger, promise: promise ) let stream = PSQLRowStream( @@ -53,7 +53,7 @@ class PSQLRowStreamTests: XCTestCase { let promise = eventLoop.makePromise(of: PSQLRowStream.self) let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", bind: [], logger: logger, jsonDecoder: JSONDecoder(), promise: promise + query: "SELECT * FROM test;", bind: [], logger: logger, promise: promise ) let dataSource = CountingDataSource() @@ -92,9 +92,7 @@ class PSQLRowStreamTests: XCTestCase { let promise = eventLoop.makePromise(of: PSQLRowStream.self) let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", bind: [], logger: logger, jsonDecoder: JSONDecoder(), promise: promise - ) - + query: "SELECT * FROM test;", bind: [], logger: logger, promise: promise) let dataSource = CountingDataSource() let stream = PSQLRowStream( rowDescription: [ @@ -144,7 +142,7 @@ class PSQLRowStreamTests: XCTestCase { let promise = eventLoop.makePromise(of: PSQLRowStream.self) let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", bind: [], logger: logger, jsonDecoder: JSONDecoder(), promise: promise + query: "SELECT * FROM test;", bind: [], logger: logger, promise: promise ) let dataSource = CountingDataSource() @@ -188,7 +186,7 @@ class PSQLRowStreamTests: XCTestCase { let promise = eventLoop.makePromise(of: PSQLRowStream.self) let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", bind: [], logger: logger, jsonDecoder: JSONDecoder(), promise: promise + query: "SELECT * FROM test;", bind: [], logger: logger, promise: promise ) let dataSource = CountingDataSource() @@ -237,7 +235,7 @@ class PSQLRowStreamTests: XCTestCase { let promise = eventLoop.makePromise(of: PSQLRowStream.self) let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", bind: [], logger: logger, jsonDecoder: JSONDecoder(), promise: promise + query: "SELECT * FROM test;", bind: [], logger: logger, promise: promise ) let dataSource = CountingDataSource() From 55d6b9da969f61f96cf618dad0bf770aa4b174f0 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 17 Feb 2022 19:11:19 +0100 Subject: [PATCH 046/292] Merge type PSQLFormat into PostgresFormat (#212) --- Sources/PostgresNIO/Data/PostgresData.swift | 4 ++-- .../PostgresNIO/Data/PostgresDataType.swift | 22 ++++++++++++++----- Sources/PostgresNIO/Data/PostgresRow.swift | 6 ++--- .../Message/PostgresMessage+Bind.swift | 4 ++-- .../PostgresMessage+RowDescription.swift | 4 ++-- .../New/Data/Array+PSQLCodable.swift | 4 ++-- .../New/Data/Bool+PSQLCodable.swift | 4 ++-- .../New/Data/Bytes+PSQLCodable.swift | 10 ++++----- .../New/Data/Date+PSQLCodable.swift | 4 ++-- .../New/Data/Decimal+PSQLCodable.swift | 4 ++-- .../New/Data/Float+PSQLCodable.swift | 8 +++---- .../New/Data/Int+PSQLCodable.swift | 20 ++++++++--------- .../New/Data/JSON+PSQLCodable.swift | 4 ++-- .../New/Data/Optional+PSQLCodable.swift | 4 ++-- .../Data/RawRepresentable+PSQLCodable.swift | 4 ++-- .../New/Data/String+PSQLCodable.swift | 4 ++-- .../New/Data/UUID+PSQLCodable.swift | 4 ++-- Sources/PostgresNIO/New/Messages/Bind.swift | 2 +- .../New/Messages/RowDescription.swift | 6 ++--- Sources/PostgresNIO/New/PSQLCodable.swift | 4 ++-- Sources/PostgresNIO/New/PSQLData.swift | 14 ++---------- Sources/PostgresNIO/Postgres+PSQLCompat.swift | 8 +++---- .../New/Data/JSON+PSQLCodableTests.swift | 2 +- .../New/Data/UUID+PSQLCodableTests.swift | 2 +- .../New/PSQLRowStreamTests.swift | 2 +- 25 files changed, 77 insertions(+), 77 deletions(-) diff --git a/Sources/PostgresNIO/Data/PostgresData.swift b/Sources/PostgresNIO/Data/PostgresData.swift index 916c27bd..96ac7023 100644 --- a/Sources/PostgresNIO/Data/PostgresData.swift +++ b/Sources/PostgresNIO/Data/PostgresData.swift @@ -16,11 +16,11 @@ public struct PostgresData: CustomStringConvertible, CustomDebugStringConvertibl /// Currently will be zero (text) or one (binary). /// In a RowDescription returned from the statement variant of Describe, /// the format code is not yet known and will always be zero. - public var formatCode: PostgresFormatCode + public var formatCode: PostgresFormat public var value: ByteBuffer? - public init(type: PostgresDataType, typeModifier: Int32? = nil, formatCode: PostgresFormatCode = .binary, value: ByteBuffer? = nil) { + public init(type: PostgresDataType, typeModifier: Int32? = nil, formatCode: PostgresFormat = .binary, value: ByteBuffer? = nil) { self.type = type self.typeModifier = typeModifier self.formatCode = formatCode diff --git a/Sources/PostgresNIO/Data/PostgresDataType.swift b/Sources/PostgresNIO/Data/PostgresDataType.swift index c9c96eb7..37520242 100644 --- a/Sources/PostgresNIO/Data/PostgresDataType.swift +++ b/Sources/PostgresNIO/Data/PostgresDataType.swift @@ -1,11 +1,14 @@ -/// The format code being used for the field. -/// Currently will be zero (text) or one (binary). -/// In a RowDescription returned from the statement variant of Describe, -/// the format code is not yet known and will always be zero. -public enum PostgresFormatCode: Int16, Codable, CustomStringConvertible { +/// The format the postgres types are encoded in on the wire. +/// +/// Currently there a two wire formats supported: +/// - text +/// - binary +public enum PostgresFormat: Int16 { case text = 0 case binary = 1 - +} + +extension PostgresFormat: CustomStringConvertible { public var description: String { switch self { case .text: return "text" @@ -14,6 +17,13 @@ public enum PostgresFormatCode: Int16, Codable, CustomStringConvertible { } } +// TODO: The Codable conformance does not make any sense. Let's remove this with next major break. +extension PostgresFormat: Codable {} + +// TODO: Renamed during 1.x. Remove this with next major break. +@available(*, deprecated, renamed: "PostgresFormat") +public typealias PostgresFormatCode = PostgresFormat + /// The data type's raw object ID. /// Use `select * from pg_type where oid = ;` to lookup more information. public struct PostgresDataType: Codable, Equatable, ExpressibleByIntegerLiteral, CustomStringConvertible, RawRepresentable { diff --git a/Sources/PostgresNIO/Data/PostgresRow.swift b/Sources/PostgresNIO/Data/PostgresRow.swift index 7c80fe91..7b08b360 100644 --- a/Sources/PostgresNIO/Data/PostgresRow.swift +++ b/Sources/PostgresNIO/Data/PostgresRow.swift @@ -1,7 +1,7 @@ public struct PostgresRow: CustomStringConvertible { final class LookupTable { let rowDescription: PostgresMessage.RowDescription - let resultFormat: [PostgresFormatCode] + let resultFormat: [PostgresFormat] struct Value { let index: Int @@ -27,7 +27,7 @@ public struct PostgresRow: CustomStringConvertible { init( rowDescription: PostgresMessage.RowDescription, - resultFormat: [PostgresFormatCode] + resultFormat: [PostgresFormat] ) { self.rowDescription = rowDescription self.resultFormat = resultFormat @@ -54,7 +54,7 @@ public struct PostgresRow: CustomStringConvertible { guard let entry = self.lookupTable.lookup(column: column) else { return nil } - let formatCode: PostgresFormatCode + let formatCode: PostgresFormat switch self.lookupTable.resultFormat.count { case 1: formatCode = self.lookupTable.resultFormat[0] default: formatCode = entry.field.formatCode diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift b/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift index a5687c40..ca8d4aa8 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift @@ -26,7 +26,7 @@ extension PostgresMessage { /// This can be zero to indicate that there are no parameters or that the parameters all use the default format (text); /// or one, in which case the specified format code is applied to all parameters; or it can equal the actual number of parameters. /// The parameter format codes. Each must presently be zero (text) or one (binary). - public var parameterFormatCodes: [PostgresFormatCode] + public var parameterFormatCodes: [PostgresFormat] /// The number of parameter values that follow (possibly zero). This must match the number of parameters needed by the query. public var parameters: [Parameter] @@ -35,7 +35,7 @@ extension PostgresMessage { /// This can be zero to indicate that there are no result columns or that the result columns should all use the default format (text); /// or one, in which case the specified format code is applied to all result columns (if any); /// or it can equal the actual number of result columns of the query. - public var resultFormatCodes: [PostgresFormatCode] + public var resultFormatCodes: [PostgresFormat] /// Serializes this message into a byte buffer. public func serialize(into buffer: inout ByteBuffer) { diff --git a/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift b/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift index 48a90c18..ee8fa919 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift @@ -29,7 +29,7 @@ extension PostgresMessage { guard let dataTypeModifier = buffer.readInteger(as: Int32.self) else { throw PostgresError.protocol("Could not read row description field data type modifier") } - guard let formatCode = buffer.readInteger(as: PostgresFormatCode.self) else { + guard let formatCode = buffer.readInteger(as: PostgresFormat.self) else { throw PostgresError.protocol("Could not read row description field format code") } return .init( @@ -65,7 +65,7 @@ extension PostgresMessage { /// Currently will be zero (text) or one (binary). /// In a RowDescription returned from the statement variant of Describe, /// the format code is not yet known and will always be zero. - public var formatCode: PostgresFormatCode + public var formatCode: PostgresFormat /// See `CustomStringConvertible`. public var description: String { diff --git a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift index 07e67c2d..d9371f47 100644 --- a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift @@ -72,7 +72,7 @@ extension Array: PSQLEncodable where Element: PSQLArrayElement { Element.psqlArrayType } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { .binary } @@ -102,7 +102,7 @@ extension Array: PSQLEncodable where Element: PSQLArrayElement { extension Array: PSQLDecodable where Element: PSQLArrayElement { - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Array { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Array { guard case .binary = format else { // currently we only support decoding arrays in binary format. throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) diff --git a/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift index 9ab2cc0f..5e097ac3 100644 --- a/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift @@ -5,11 +5,11 @@ extension Bool: PSQLCodable { .bool } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { .binary } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Bool { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Bool { guard type == .bool else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } diff --git a/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift index be8b2dd8..22298026 100644 --- a/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift @@ -7,7 +7,7 @@ extension PSQLEncodable where Self: Sequence, Self.Element == UInt8 { .bytea } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { .binary } @@ -21,7 +21,7 @@ extension ByteBuffer: PSQLCodable { .bytea } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { .binary } @@ -30,7 +30,7 @@ extension ByteBuffer: PSQLCodable { byteBuffer.writeBuffer(©OfSelf) } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { return buffer } } @@ -40,7 +40,7 @@ extension Data: PSQLCodable { .bytea } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { .binary } @@ -48,7 +48,7 @@ extension Data: PSQLCodable { byteBuffer.writeBytes(self) } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { return buffer.readData(length: buffer.readableBytes, byteTransferStrategy: .automatic)! } } diff --git a/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift index f78a915b..7639cd66 100644 --- a/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift @@ -6,11 +6,11 @@ extension Date: PSQLCodable { .timestamptz } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { .binary } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch type { case .timestamp, .timestamptz: guard buffer.readableBytes == 8, let microseconds = buffer.readInteger(as: Int64.self) else { diff --git a/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift index de42a874..d36f5b57 100644 --- a/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift @@ -6,11 +6,11 @@ extension Decimal: PSQLCodable { .numeric } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { .binary } - static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Decimal { + static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Decimal { switch (format, type) { case (.binary, .numeric): guard let numeric = PostgresNumeric(buffer: &byteBuffer) else { diff --git a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift index 6a551e64..d4560dc3 100644 --- a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift @@ -5,11 +5,11 @@ extension Float: PSQLCodable { .float4 } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { .binary } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Float { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Float { switch (format, type) { case (.binary, .float4): guard buffer.readableBytes == 4, let float = buffer.psqlReadFloat() else { @@ -41,11 +41,11 @@ extension Double: PSQLCodable { .float8 } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { .binary } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Double { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Double { switch (format, type) { case (.binary, .float4): guard buffer.readableBytes == 4, let float = buffer.psqlReadFloat() else { diff --git a/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift index 2c421e92..abd5d19d 100644 --- a/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift @@ -5,12 +5,12 @@ extension UInt8: PSQLCodable { .char } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { .binary } // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch type { case .bpchar, .char: guard buffer.readableBytes == 1, let value = buffer.readInteger(as: UInt8.self) else { @@ -35,12 +35,12 @@ extension Int16: PSQLCodable { .int2 } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { .binary } // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { @@ -68,12 +68,12 @@ extension Int32: PSQLCodable { .int4 } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { .binary } // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { @@ -106,12 +106,12 @@ extension Int64: PSQLCodable { .int8 } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { .binary } // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { @@ -156,12 +156,12 @@ extension Int: PSQLCodable { } } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { .binary } // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { diff --git a/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift index 0a321003..3f9b1093 100644 --- a/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift @@ -10,11 +10,11 @@ extension PSQLCodable where Self: Codable { .jsonb } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { .binary } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (.binary, .jsonb): guard JSONBVersionByte == buffer.readInteger(as: UInt8.self) else { diff --git a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift index 99332221..fa19df26 100644 --- a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift @@ -1,7 +1,7 @@ import NIOCore extension Optional: PSQLDecodable where Wrapped: PSQLDecodable { - static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Optional { + static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Optional { preconditionFailure("This code path should never be hit.") // The code path for decoding an optional should be: // -> PSQLData.decode(as: String?.self) @@ -20,7 +20,7 @@ extension Optional: PSQLEncodable where Wrapped: PSQLEncodable { } } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { switch self { case .some(let value): return value.psqlFormat diff --git a/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift index 02bafa39..367fa45a 100644 --- a/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift @@ -5,11 +5,11 @@ extension PSQLCodable where Self: RawRepresentable, RawValue: PSQLCodable { self.rawValue.psqlType } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { self.rawValue.psqlFormat } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { guard let rawValue = try? RawValue.decode(from: &buffer, type: type, format: format, context: context), let selfValue = Self.init(rawValue: rawValue) else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) diff --git a/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift index cff48330..970f7e48 100644 --- a/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift @@ -6,7 +6,7 @@ extension String: PSQLCodable { .text } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { .binary } @@ -14,7 +14,7 @@ extension String: PSQLCodable { byteBuffer.writeString(self) } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> String { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> String { switch (format, type) { case (_, .varchar), (_, .text), diff --git a/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift index 5e259c4b..eef54983 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift @@ -8,7 +8,7 @@ extension UUID: PSQLCodable { .uuid } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { .binary } @@ -22,7 +22,7 @@ extension UUID: PSQLCodable { ]) } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> UUID { + static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> UUID { switch (format, type) { case (.binary, .uuid): guard let uuid = buffer.readUUID() else { diff --git a/Sources/PostgresNIO/New/Messages/Bind.swift b/Sources/PostgresNIO/New/Messages/Bind.swift index 110d7866..500a13b9 100644 --- a/Sources/PostgresNIO/New/Messages/Bind.swift +++ b/Sources/PostgresNIO/New/Messages/Bind.swift @@ -42,7 +42,7 @@ extension PSQLFrontendMessage { // result columns of the query. buffer.writeInteger(1, as: Int16.self) // The result-column format codes. Each must presently be zero (text) or one (binary). - buffer.writeInteger(PSQLFormat.binary.rawValue, as: Int16.self) + buffer.writeInteger(PostgresFormat.binary.rawValue, as: Int16.self) } } } diff --git a/Sources/PostgresNIO/New/Messages/RowDescription.swift b/Sources/PostgresNIO/New/Messages/RowDescription.swift index 4f470847..cac32eac 100644 --- a/Sources/PostgresNIO/New/Messages/RowDescription.swift +++ b/Sources/PostgresNIO/New/Messages/RowDescription.swift @@ -33,7 +33,7 @@ struct RowDescription: PSQLBackendMessage.PayloadDecodable, Equatable { /// The format being used for the field. Currently will be text or binary. In a RowDescription returned /// from the statement variant of Describe, the format code is not yet known and will always be text. - var format: PSQLFormat + var format: PostgresFormat } static func decode(from buffer: inout ByteBuffer) throws -> Self { @@ -57,8 +57,8 @@ struct RowDescription: PSQLBackendMessage.PayloadDecodable, Equatable { throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(18, actual: buffer.readableBytes) } - guard let format = PSQLFormat(rawValue: formatCodeInt16) else { - throw PSQLPartialDecodingError.valueNotRawRepresentable(value: formatCodeInt16, asType: PSQLFormat.self) + guard let format = PostgresFormat(rawValue: formatCodeInt16) else { + throw PSQLPartialDecodingError.valueNotRawRepresentable(value: formatCodeInt16, asType: PostgresFormat.self) } let field = Column( diff --git a/Sources/PostgresNIO/New/PSQLCodable.swift b/Sources/PostgresNIO/New/PSQLCodable.swift index b5434edd..c523eda8 100644 --- a/Sources/PostgresNIO/New/PSQLCodable.swift +++ b/Sources/PostgresNIO/New/PSQLCodable.swift @@ -6,7 +6,7 @@ protocol PSQLEncodable { var psqlType: PSQLDataType { get } /// identifies the postgres format that is used to encode the value into `byteBuffer` in `encode` - var psqlFormat: PSQLFormat { get } + var psqlFormat: PostgresFormat { get } /// Encode the entity into the `byteBuffer` in Postgres binary format, without setting /// the byte count. This method is called from the default `encodeRaw` implementation. @@ -32,7 +32,7 @@ protocol PSQLDecodable { /// - context: A `PSQLDecodingContext` providing context for decoding. This includes a `JSONDecoder` /// to use when decoding json and metadata to create better errors. /// - Returns: A decoded object - static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self + static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self } /// A type that can be encoded into and decoded from a postgres binary format diff --git a/Sources/PostgresNIO/New/PSQLData.swift b/Sources/PostgresNIO/New/PSQLData.swift index 840d798a..4d1c3acc 100644 --- a/Sources/PostgresNIO/New/PSQLData.swift +++ b/Sources/PostgresNIO/New/PSQLData.swift @@ -1,23 +1,13 @@ import NIOCore -/// The format the postgres types are encoded in on the wire. -/// -/// Currently there a two wire formats supported: -/// - text -/// - binary -enum PSQLFormat: Int16 { - case text = 0 - case binary = 1 -} - struct PSQLData: Equatable { @usableFromInline var bytes: ByteBuffer? @usableFromInline var dataType: PSQLDataType - @usableFromInline var format: PSQLFormat + @usableFromInline var format: PostgresFormat /// use this only for testing - init(bytes: ByteBuffer?, dataType: PSQLDataType, format: PSQLFormat) { + init(bytes: ByteBuffer?, dataType: PSQLDataType, format: PostgresFormat) { self.bytes = bytes self.dataType = dataType self.format = format diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index 7af85fd3..545e1efb 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -32,7 +32,7 @@ extension PostgresData: PSQLEncodable { PSQLDataType(Int32(self.type.rawValue)) } - var psqlFormat: PSQLFormat { + var psqlFormat: PostgresFormat { .binary } @@ -53,7 +53,7 @@ extension PostgresData: PSQLEncodable { } extension PostgresData: PSQLDecodable { - static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> PostgresData { + static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> PostgresData { let myBuffer = byteBuffer.readSlice(length: byteBuffer.readableBytes)! return PostgresData(type: PostgresDataType(UInt32(type.rawValue)), typeModifier: nil, formatCode: .binary, value: myBuffer) @@ -102,8 +102,8 @@ extension PSQLError { } } -extension PostgresFormatCode { - init(psqlFormatCode: PSQLFormat) { +extension PostgresFormat { + init(psqlFormatCode: PostgresFormat) { switch psqlFormatCode { case .binary: self = .binary diff --git a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift index 325641e8..57106393 100644 --- a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift @@ -38,7 +38,7 @@ class JSON_PSQLCodableTests: XCTestCase { } func testDecodeFromJSONAsText() { - let combinations : [(PSQLFormat, PSQLDataType)] = [ + let combinations : [(PostgresFormat, PSQLDataType)] = [ (.text, .json), (.text, .jsonb), ] var buffer = ByteBuffer() diff --git a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift index 8b1be81e..3abf035b 100644 --- a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift @@ -40,7 +40,7 @@ class UUID_PSQLCodableTests: XCTestCase { } func testDecodeFromString() { - let options: [(PSQLFormat, PSQLDataType)] = [ + let options: [(PostgresFormat, PSQLDataType)] = [ (.binary, .text), (.binary, .varchar), (.text, .uuid), diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift index ea303f20..fce52d13 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -287,7 +287,7 @@ class PSQLRowStreamTests: XCTestCase { XCTAssertEqual(stream.commandTag, "SELECT 6") } - func makeColumnDescription(name: String, dataType: PSQLDataType, format: PSQLFormat) -> RowDescription.Column { + func makeColumnDescription(name: String, dataType: PSQLDataType, format: PostgresFormat) -> RowDescription.Column { RowDescription.Column( name: "test", tableOID: 123, From 0b5c40077d5053f4ccf5ba9f23902e4ac9133f7b Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 17 Feb 2022 19:53:17 +0100 Subject: [PATCH 047/292] Remove PSQLJSONDecoder (#216) --- Sources/PostgresNIO/New/PSQL+JSON.swift | 6 ------ Sources/PostgresNIO/New/PSQLCodable.swift | 4 ++-- Sources/PostgresNIO/New/PSQLRow.swift | 4 ++-- Sources/PostgresNIO/Postgres+PSQLCompat.swift | 14 -------------- .../Utilities/PostgresJSONDecoder.swift | 14 +++++++++++++- .../New/Extensions/PSQLCoding+TestUtils.swift | 2 +- 6 files changed, 18 insertions(+), 26 deletions(-) diff --git a/Sources/PostgresNIO/New/PSQL+JSON.swift b/Sources/PostgresNIO/New/PSQL+JSON.swift index 564a2cc1..4183d204 100644 --- a/Sources/PostgresNIO/New/PSQL+JSON.swift +++ b/Sources/PostgresNIO/New/PSQL+JSON.swift @@ -7,10 +7,4 @@ protocol PSQLJSONEncoder { func encode(_ value: T, into buffer: inout ByteBuffer) throws } -protocol PSQLJSONDecoder { - func decode(_ type: T.Type, from buffer: ByteBuffer) throws -> T -} - extension JSONEncoder: PSQLJSONEncoder {} -extension JSONDecoder: PSQLJSONDecoder {} - diff --git a/Sources/PostgresNIO/New/PSQLCodable.swift b/Sources/PostgresNIO/New/PSQLCodable.swift index c523eda8..143ce463 100644 --- a/Sources/PostgresNIO/New/PSQLCodable.swift +++ b/Sources/PostgresNIO/New/PSQLCodable.swift @@ -60,7 +60,7 @@ struct PSQLEncodingContext { struct PSQLDecodingContext { - let jsonDecoder: PSQLJSONDecoder + let jsonDecoder: PostgresJSONDecoder let columnIndex: Int let columnName: String @@ -68,7 +68,7 @@ struct PSQLDecodingContext { let file: String let line: Int - init(jsonDecoder: PSQLJSONDecoder, columnName: String, columnIndex: Int, file: String, line: Int) { + init(jsonDecoder: PostgresJSONDecoder, columnName: String, columnIndex: Int, file: String, line: Int) { self.jsonDecoder = jsonDecoder self.columnName = columnName self.columnIndex = columnIndex diff --git a/Sources/PostgresNIO/New/PSQLRow.swift b/Sources/PostgresNIO/New/PSQLRow.swift index 9fbc7f14..f76f9eef 100644 --- a/Sources/PostgresNIO/New/PSQLRow.swift +++ b/Sources/PostgresNIO/New/PSQLRow.swift @@ -29,7 +29,7 @@ extension PSQLRow { /// - type: The type to decode the data into /// - Throws: The error of the decoding implementation. See also `PSQLDecodable` protocol for this. /// - Returns: The decoded value of Type T. - func decode(column: String, as type: T.Type, jsonDecoder: JSONDecoder, file: String = #file, line: Int = #line) throws -> T { + func decode(column: String, as type: T.Type, jsonDecoder: JSONDecoder, file: String = #file, line: Int = #line) throws -> T { guard let index = self.lookupTable[column] else { preconditionFailure("A column '\(column)' does not exist.") } @@ -44,7 +44,7 @@ extension PSQLRow { /// - type: The type to decode the data into /// - Throws: The error of the decoding implementation. See also `PSQLDecodable` protocol for this. /// - Returns: The decoded value of Type T. - func decode(column index: Int, as type: T.Type, jsonDecoder: JSONDecoder, file: String = #file, line: Int = #line) throws -> T { + func decode(column index: Int, as type: T.Type, jsonDecoder: JSONDecoder, file: String = #file, line: Int = #line) throws -> T { precondition(index < self.data.columnCount) let column = self.columns[index] diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index 545e1efb..bb540a8e 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -1,19 +1,5 @@ import NIOCore -struct PostgresJSONDecoderWrapper: PSQLJSONDecoder { - let downstream: PostgresJSONDecoder - - init(_ downstream: PostgresJSONDecoder) { - self.downstream = downstream - } - - func decode(_ type: T.Type, from buffer: ByteBuffer) throws -> T where T : Decodable { - var buffer = buffer - let data = buffer.readData(length: buffer.readableBytes)! - return try self.downstream.decode(T.self, from: data) - } -} - struct PostgresJSONEncoderWrapper: PSQLJSONEncoder { let downstream: PostgresJSONEncoder diff --git a/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift b/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift index 78bdebb2..5a87a182 100644 --- a/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift +++ b/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift @@ -1,10 +1,22 @@ -import Foundation +import class Foundation.JSONDecoder +import struct Foundation.Data +import NIOFoundationCompat /// A protocol that mimicks the Foundation `JSONDecoder.decode(_:from:)` function. /// Conform a non-Foundation JSON decoder to this protocol if you want PostgresNIO to be /// able to use it when decoding JSON & JSONB values (see `PostgresNIO._defaultJSONDecoder`) public protocol PostgresJSONDecoder { func decode(_ type: T.Type, from data: Data) throws -> T where T : Decodable + + func decode(_ type: T.Type, from buffer: ByteBuffer) throws -> T +} + +extension PostgresJSONDecoder { + public func decode(_ type: T.Type, from buffer: ByteBuffer) throws -> T { + var copy = buffer + let data = copy.readData(length: buffer.readableBytes)! + return try self.decode(type, from: data) + } } extension JSONDecoder: PostgresJSONDecoder {} diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift index b6f2e1d1..49d68057 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift @@ -8,7 +8,7 @@ extension PSQLFrontendMessageEncoder { } extension PSQLDecodingContext { - static func forTests(columnName: String = "unknown", columnIndex: Int = 0, jsonDecoder: PSQLJSONDecoder = JSONDecoder(), file: String = #file, line: Int = #line) -> Self { + static func forTests(columnName: String = "unknown", columnIndex: Int = 0, jsonDecoder: PostgresJSONDecoder = JSONDecoder(), file: String = #file, line: Int = #line) -> Self { Self(jsonDecoder: JSONDecoder(), columnName: columnName, columnIndex: columnIndex, file: file, line: line) } } From eaef2084327ac950352bbb2e7144a012178f5c80 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 17 Feb 2022 20:44:31 +0100 Subject: [PATCH 048/292] Merge PSQLDataType with PostgresDataType (#213) --- .../PostgresNIO/Data/PostgresDataType.swift | 21 ++- .../New/Data/Array+PSQLCodable.swift | 61 +++---- .../New/Data/Bool+PSQLCodable.swift | 6 +- .../New/Data/Bytes+PSQLCodable.swift | 12 +- .../New/Data/Date+PSQLCodable.swift | 4 +- .../New/Data/Decimal+PSQLCodable.swift | 14 +- .../New/Data/Float+PSQLCodable.swift | 8 +- .../New/Data/Int+PSQLCodable.swift | 56 +++--- .../New/Data/JSON+PSQLCodable.swift | 4 +- .../New/Data/Optional+PSQLCodable.swift | 4 +- .../Data/RawRepresentable+PSQLCodable.swift | 4 +- .../New/Data/String+PSQLCodable.swift | 4 +- .../New/Data/UUID+PSQLCodable.swift | 4 +- .../New/Messages/ParameterDescription.swift | 8 +- Sources/PostgresNIO/New/Messages/Parse.swift | 2 +- .../New/Messages/RowDescription.swift | 6 +- Sources/PostgresNIO/New/PSQLCodable.swift | 4 +- Sources/PostgresNIO/New/PSQLData.swift | 171 +----------------- Sources/PostgresNIO/New/PSQLError.swift | 6 +- Sources/PostgresNIO/Postgres+PSQLCompat.swift | 8 +- .../New/Data/JSON+PSQLCodableTests.swift | 2 +- .../New/Data/String+PSQLCodableTests.swift | 6 +- .../New/Data/UUID+PSQLCodableTests.swift | 6 +- .../Messages/ParameterDescriptionTests.swift | 4 +- .../New/Messages/ParseTests.swift | 16 +- .../New/PSQLRowStreamTests.swift | 2 +- 26 files changed, 135 insertions(+), 308 deletions(-) diff --git a/Sources/PostgresNIO/Data/PostgresDataType.swift b/Sources/PostgresNIO/Data/PostgresDataType.swift index 37520242..1652048b 100644 --- a/Sources/PostgresNIO/Data/PostgresDataType.swift +++ b/Sources/PostgresNIO/Data/PostgresDataType.swift @@ -26,7 +26,7 @@ public typealias PostgresFormatCode = PostgresFormat /// The data type's raw object ID. /// Use `select * from pg_type where oid = ;` to lookup more information. -public struct PostgresDataType: Codable, Equatable, ExpressibleByIntegerLiteral, CustomStringConvertible, RawRepresentable { +public struct PostgresDataType: RawRepresentable, Equatable, CustomStringConvertible { /// `0` public static let null = PostgresDataType(0) /// `16` @@ -125,12 +125,7 @@ public struct PostgresDataType: Codable, Equatable, ExpressibleByIntegerLiteral, public var isUserDefined: Bool { self.rawValue >= 1 << 14 } - - /// See `ExpressibleByIntegerLiteral.init(integerLiteral:)` - public init(integerLiteral value: UInt32) { - self.init(value) - } - + public init(_ rawValue: UInt32) { self.rawValue = rawValue } @@ -138,7 +133,7 @@ public struct PostgresDataType: Codable, Equatable, ExpressibleByIntegerLiteral, public init?(rawValue: UInt32) { self.init(rawValue) } - + /// Returns the known SQL name, if one exists. /// Note: This only supports a limited subset of all PSQL types and is meant for convenience only. public var knownSQLName: String? { @@ -237,3 +232,13 @@ public struct PostgresDataType: Codable, Equatable, ExpressibleByIntegerLiteral, return self.knownSQLName ?? "UNKNOWN \(self.rawValue)" } } + +// TODO: The Codable conformance does not make any sense. Let's remove this with next major break. +extension PostgresDataType: Codable {} + +// TODO: The ExpressibleByIntegerLiteral conformance does not make any sense and is not used anywhere. Remove with next major break. +extension PostgresDataType: ExpressibleByIntegerLiteral { + public init(integerLiteral value: UInt32) { + self.init(value) + } +} diff --git a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift index d9371f47..bad901dc 100644 --- a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift @@ -3,72 +3,72 @@ import struct Foundation.UUID /// A type, of which arrays can be encoded into and decoded from a postgres binary format protocol PSQLArrayElement: PSQLCodable { - static var psqlArrayType: PSQLDataType { get } - static var psqlArrayElementType: PSQLDataType { get } + static var psqlArrayType: PostgresDataType { get } + static var psqlArrayElementType: PostgresDataType { get } } extension Bool: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .boolArray } - static var psqlArrayElementType: PSQLDataType { .bool } + static var psqlArrayType: PostgresDataType { .boolArray } + static var psqlArrayElementType: PostgresDataType { .bool } } extension ByteBuffer: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .byteaArray } - static var psqlArrayElementType: PSQLDataType { .bytea } + static var psqlArrayType: PostgresDataType { .byteaArray } + static var psqlArrayElementType: PostgresDataType { .bytea } } extension UInt8: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .charArray } - static var psqlArrayElementType: PSQLDataType { .char } + static var psqlArrayType: PostgresDataType { .charArray } + static var psqlArrayElementType: PostgresDataType { .char } } extension Int16: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .int2Array } - static var psqlArrayElementType: PSQLDataType { .int2 } + static var psqlArrayType: PostgresDataType { .int2Array } + static var psqlArrayElementType: PostgresDataType { .int2 } } extension Int32: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .int4Array } - static var psqlArrayElementType: PSQLDataType { .int4 } + static var psqlArrayType: PostgresDataType { .int4Array } + static var psqlArrayElementType: PostgresDataType { .int4 } } extension Int64: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .int8Array } - static var psqlArrayElementType: PSQLDataType { .int8 } + static var psqlArrayType: PostgresDataType { .int8Array } + static var psqlArrayElementType: PostgresDataType { .int8 } } extension Int: PSQLArrayElement { #if (arch(i386) || arch(arm)) - static var psqlArrayType: PSQLDataType { .int4Array } - static var psqlArrayElementType: PSQLDataType { .int4 } + static var psqlArrayType: PostgresDataType { .int4Array } + static var psqlArrayElementType: PostgresDataType { .int4 } #else - static var psqlArrayType: PSQLDataType { .int8Array } - static var psqlArrayElementType: PSQLDataType { .int8 } + static var psqlArrayType: PostgresDataType { .int8Array } + static var psqlArrayElementType: PostgresDataType { .int8 } #endif } extension Float: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .float4Array } - static var psqlArrayElementType: PSQLDataType { .float4 } + static var psqlArrayType: PostgresDataType { .float4Array } + static var psqlArrayElementType: PostgresDataType { .float4 } } extension Double: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .float8Array } - static var psqlArrayElementType: PSQLDataType { .float8 } + static var psqlArrayType: PostgresDataType { .float8Array } + static var psqlArrayElementType: PostgresDataType { .float8 } } extension String: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .textArray } - static var psqlArrayElementType: PSQLDataType { .text } + static var psqlArrayType: PostgresDataType { .textArray } + static var psqlArrayElementType: PostgresDataType { .text } } extension UUID: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .uuidArray } - static var psqlArrayElementType: PSQLDataType { .uuid } + static var psqlArrayType: PostgresDataType { .uuidArray } + static var psqlArrayElementType: PostgresDataType { .uuid } } extension Array: PSQLEncodable where Element: PSQLArrayElement { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { Element.psqlArrayType } @@ -101,20 +101,19 @@ extension Array: PSQLEncodable where Element: PSQLArrayElement { } extension Array: PSQLDecodable where Element: PSQLArrayElement { - - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Array { + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Array { guard case .binary = format else { // currently we only support decoding arrays in binary format. throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } - guard let (isNotEmpty, b, element) = buffer.readMultipleIntegers(endianness: .big, as: (Int32, Int32, Int32).self), + guard let (isNotEmpty, b, element) = buffer.readMultipleIntegers(endianness: .big, as: (Int32, Int32, UInt32).self), 0 <= isNotEmpty, isNotEmpty <= 1, b == 0 else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } - let elementType = PSQLDataType(rawValue: element) + let elementType = PostgresDataType(element) guard isNotEmpty == 1 else { return [] diff --git a/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift index 5e097ac3..4bd4bb33 100644 --- a/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift @@ -1,15 +1,15 @@ import NIOCore extension Bool: PSQLCodable { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { .bool } var psqlFormat: PostgresFormat { .binary } - - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Bool { + + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Bool { guard type == .bool else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } diff --git a/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift index 22298026..b359f3ca 100644 --- a/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift @@ -3,7 +3,7 @@ import NIOCore import NIOFoundationCompat extension PSQLEncodable where Self: Sequence, Self.Element == UInt8 { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { .bytea } @@ -17,7 +17,7 @@ extension PSQLEncodable where Self: Sequence, Self.Element == UInt8 { } extension ByteBuffer: PSQLCodable { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { .bytea } @@ -29,14 +29,14 @@ extension ByteBuffer: PSQLCodable { var copyOfSelf = self // dirty hack byteBuffer.writeBuffer(©OfSelf) } - - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { return buffer } } extension Data: PSQLCodable { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { .bytea } @@ -48,7 +48,7 @@ extension Data: PSQLCodable { byteBuffer.writeBytes(self) } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { return buffer.readData(length: buffer.readableBytes, byteTransferStrategy: .automatic)! } } diff --git a/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift index 7639cd66..868a0929 100644 --- a/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift @@ -2,7 +2,7 @@ import NIOCore import struct Foundation.Date extension Date: PSQLCodable { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { .timestamptz } @@ -10,7 +10,7 @@ extension Date: PSQLCodable { .binary } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch type { case .timestamp, .timestamptz: guard buffer.readableBytes == 8, let microseconds = buffer.readInteger(as: Int64.self) else { diff --git a/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift index d36f5b57..990b9ebf 100644 --- a/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift @@ -2,7 +2,7 @@ import NIOCore import struct Foundation.Decimal extension Decimal: PSQLCodable { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { .numeric } @@ -10,20 +10,20 @@ extension Decimal: PSQLCodable { .binary } - static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Decimal { + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (.binary, .numeric): - guard let numeric = PostgresNumeric(buffer: &byteBuffer) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: byteBuffer, context: context) + guard let numeric = PostgresNumeric(buffer: &buffer) else { + throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } return numeric.decimal case (.text, .numeric): - guard let string = byteBuffer.readString(length: byteBuffer.readableBytes), let value = Decimal(string: string) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: byteBuffer, context: context) + guard let string = buffer.readString(length: buffer.readableBytes), let value = Decimal(string: string) else { + throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } return value default: - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: byteBuffer, context: context) + throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } } diff --git a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift index d4560dc3..738160eb 100644 --- a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift @@ -1,7 +1,7 @@ import NIOCore extension Float: PSQLCodable { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { .float4 } @@ -9,7 +9,7 @@ extension Float: PSQLCodable { .binary } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Float { + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (.binary, .float4): guard buffer.readableBytes == 4, let float = buffer.psqlReadFloat() else { @@ -37,7 +37,7 @@ extension Float: PSQLCodable { } extension Double: PSQLCodable { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { .float8 } @@ -45,7 +45,7 @@ extension Double: PSQLCodable { .binary } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Double { + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (.binary, .float4): guard buffer.readableBytes == 4, let float = buffer.psqlReadFloat() else { diff --git a/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift index abd5d19d..41c411c3 100644 --- a/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift @@ -1,16 +1,15 @@ import NIOCore extension UInt8: PSQLCodable { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { .char } - + var psqlFormat: PostgresFormat { .binary } - - // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch type { case .bpchar, .char: guard buffer.readableBytes == 1, let value = buffer.readInteger(as: UInt8.self) else { @@ -22,8 +21,7 @@ extension UInt8: PSQLCodable { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } } - - // encoding + func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { byteBuffer.writeInteger(self, as: UInt8.self) } @@ -31,16 +29,15 @@ extension UInt8: PSQLCodable { extension Int16: PSQLCodable { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { .int2 } - + var psqlFormat: PostgresFormat { .binary } - - // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { @@ -56,24 +53,22 @@ extension Int16: PSQLCodable { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } } - - // encoding + func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { byteBuffer.writeInteger(self, as: Int16.self) } } extension Int32: PSQLCodable { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { .int4 } var psqlFormat: PostgresFormat { .binary } - - // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { @@ -94,24 +89,22 @@ extension Int32: PSQLCodable { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } } - - // encoding + func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { byteBuffer.writeInteger(self, as: Int32.self) } } extension Int64: PSQLCodable { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { .int8 } - + var psqlFormat: PostgresFormat { .binary } - - // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { @@ -137,15 +130,14 @@ extension Int64: PSQLCodable { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } } - - // encoding + func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { byteBuffer.writeInteger(self, as: Int64.self) } } extension Int: PSQLCodable { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { switch self.bitWidth { case Int32.bitWidth: return .int4 @@ -159,9 +151,8 @@ extension Int: PSQLCodable { var psqlFormat: PostgresFormat { .binary } - - // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { @@ -187,8 +178,7 @@ extension Int: PSQLCodable { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } } - - // encoding + func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { byteBuffer.writeInteger(self, as: Int.self) } diff --git a/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift index 3f9b1093..7dc9348d 100644 --- a/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift @@ -6,7 +6,7 @@ import class Foundation.JSONDecoder private let JSONBVersionByte: UInt8 = 0x01 extension PSQLCodable where Self: Codable { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { .jsonb } @@ -14,7 +14,7 @@ extension PSQLCodable where Self: Codable { .binary } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (.binary, .jsonb): guard JSONBVersionByte == buffer.readInteger(as: UInt8.self) else { diff --git a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift index fa19df26..53aa0f3a 100644 --- a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift @@ -1,7 +1,7 @@ import NIOCore extension Optional: PSQLDecodable where Wrapped: PSQLDecodable { - static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Optional { + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { preconditionFailure("This code path should never be hit.") // The code path for decoding an optional should be: // -> PSQLData.decode(as: String?.self) @@ -11,7 +11,7 @@ extension Optional: PSQLDecodable where Wrapped: PSQLDecodable { } extension Optional: PSQLEncodable where Wrapped: PSQLEncodable { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { switch self { case .some(let value): return value.psqlType diff --git a/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift index 367fa45a..706f58d3 100644 --- a/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift @@ -1,7 +1,7 @@ import NIOCore extension PSQLCodable where Self: RawRepresentable, RawValue: PSQLCodable { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { self.rawValue.psqlType } @@ -9,7 +9,7 @@ extension PSQLCodable where Self: RawRepresentable, RawValue: PSQLCodable { self.rawValue.psqlFormat } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { guard let rawValue = try? RawValue.decode(from: &buffer, type: type, format: format, context: context), let selfValue = Self.init(rawValue: rawValue) else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) diff --git a/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift index 970f7e48..ca59f0e2 100644 --- a/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift @@ -2,7 +2,7 @@ import NIOCore import struct Foundation.UUID extension String: PSQLCodable { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { .text } @@ -14,7 +14,7 @@ extension String: PSQLCodable { byteBuffer.writeString(self) } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> String { + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (_, .varchar), (_, .text), diff --git a/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift index eef54983..f7e738c2 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift @@ -4,7 +4,7 @@ import typealias Foundation.uuid_t extension UUID: PSQLCodable { - var psqlType: PSQLDataType { + var psqlType: PostgresDataType { .uuid } @@ -22,7 +22,7 @@ extension UUID: PSQLCodable { ]) } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> UUID { + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (.binary, .uuid): guard let uuid = buffer.readUUID() else { diff --git a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift index 971b3ac7..bd468c44 100644 --- a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift +++ b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift @@ -4,7 +4,7 @@ extension PSQLBackendMessage { struct ParameterDescription: PayloadDecodable, Equatable { /// Specifies the object ID of the parameter data type. - var dataTypes: [PSQLDataType] + var dataTypes: [PostgresDataType] static func decode(from buffer: inout ByteBuffer) throws -> Self { let parameterCount = try buffer.throwingReadInteger(as: Int16.self) @@ -12,12 +12,12 @@ extension PSQLBackendMessage { throw PSQLPartialDecodingError.integerMustBePositiveOrNull(parameterCount) } - var result = [PSQLDataType]() + var result = [PostgresDataType]() result.reserveCapacity(Int(parameterCount)) for _ in 0.. Self + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self } /// A type that can be encoded into and decoded from a postgres binary format diff --git a/Sources/PostgresNIO/New/PSQLData.swift b/Sources/PostgresNIO/New/PSQLData.swift index 4d1c3acc..9131834e 100644 --- a/Sources/PostgresNIO/New/PSQLData.swift +++ b/Sources/PostgresNIO/New/PSQLData.swift @@ -3,11 +3,11 @@ import NIOCore struct PSQLData: Equatable { @usableFromInline var bytes: ByteBuffer? - @usableFromInline var dataType: PSQLDataType + @usableFromInline var dataType: PostgresDataType @usableFromInline var format: PostgresFormat /// use this only for testing - init(bytes: ByteBuffer?, dataType: PSQLDataType, format: PostgresFormat) { + init(bytes: ByteBuffer?, dataType: PostgresDataType, format: PostgresFormat) { self.bytes = bytes self.dataType = dataType self.format = format @@ -38,170 +38,3 @@ struct PSQLData: Equatable { } } } - -struct PSQLDataType: RawRepresentable, Equatable, CustomStringConvertible { - typealias RawValue = Int32 - - /// The raw data type code recognized by PostgreSQL. - var rawValue: Int32 - - /// `0` - static let null = PSQLDataType(0) - /// `16` - static let bool = PSQLDataType(16) - /// `17` - static let bytea = PSQLDataType(17) - /// `18` - static let char = PSQLDataType(18) - /// `19` - static let name = PSQLDataType(19) - /// `20` - static let int8 = PSQLDataType(20) - /// `21` - static let int2 = PSQLDataType(21) - /// `23` - static let int4 = PSQLDataType(23) - /// `24` - static let regproc = PSQLDataType(24) - /// `25` - static let text = PSQLDataType(25) - /// `26` - static let oid = PSQLDataType(26) - /// `114` - static let json = PSQLDataType(114) - /// `194` pg_node_tree - static let pgNodeTree = PSQLDataType(194) - /// `600` - static let point = PSQLDataType(600) - /// `700` - static let float4 = PSQLDataType(700) - /// `701` - static let float8 = PSQLDataType(701) - /// `790` - static let money = PSQLDataType(790) - /// `1000` _bool - static let boolArray = PSQLDataType(1000) - /// `1001` _bytea - static let byteaArray = PSQLDataType(1001) - /// `1002` _char - static let charArray = PSQLDataType(1002) - /// `1003` _name - static let nameArray = PSQLDataType(1003) - /// `1005` _int2 - static let int2Array = PSQLDataType(1005) - /// `1007` _int4 - static let int4Array = PSQLDataType(1007) - /// `1009` _text - static let textArray = PSQLDataType(1009) - /// `1015` _varchar - static let varcharArray = PSQLDataType(1015) - /// `1016` _int8 - static let int8Array = PSQLDataType(1016) - /// `1017` _point - static let pointArray = PSQLDataType(1017) - /// `1021` _float4 - static let float4Array = PSQLDataType(1021) - /// `1022` _float8 - static let float8Array = PSQLDataType(1022) - /// `1034` _aclitem - static let aclitemArray = PSQLDataType(1034) - /// `1042` - static let bpchar = PSQLDataType(1042) - /// `1043` - static let varchar = PSQLDataType(1043) - /// `1082` - static let date = PSQLDataType(1082) - /// `1083` - static let time = PSQLDataType(1083) - /// `1114` - static let timestamp = PSQLDataType(1114) - /// `1115` _timestamp - static let timestampArray = PSQLDataType(1115) - /// `1184` - static let timestamptz = PSQLDataType(1184) - /// `1266` - static let timetz = PSQLDataType(1266) - /// `1700` - static let numeric = PSQLDataType(1700) - /// `2278` - static let void = PSQLDataType(2278) - /// `2950` - static let uuid = PSQLDataType(2950) - /// `2951` _uuid - static let uuidArray = PSQLDataType(2951) - /// `3802` - static let jsonb = PSQLDataType(3802) - /// `3807` _jsonb - static let jsonbArray = PSQLDataType(3807) - - /// Returns `true` if the type's raw value is greater than `2^14`. - /// This _appears_ to be true for all user-defined types, but I don't - /// have any documentation to back this up. - var isUserDefined: Bool { - self.rawValue >= 1 << 14 - } - - init(_ rawValue: Int32) { - self.rawValue = rawValue - } - - init(rawValue: Int32) { - self.init(rawValue) - } - - /// Returns the known SQL name, if one exists. - /// Note: This only supports a limited subset of all PSQL types and is meant for convenience only. - var knownSQLName: String? { - switch self { - case .bool: return "BOOLEAN" - case .bytea: return "BYTEA" - case .char: return "CHAR" - case .name: return "NAME" - case .int8: return "BIGINT" - case .int2: return "SMALLINT" - case .int4: return "INTEGER" - case .regproc: return "REGPROC" - case .text: return "TEXT" - case .oid: return "OID" - case .json: return "JSON" - case .pgNodeTree: return "PGNODETREE" - case .point: return "POINT" - case .float4: return "REAL" - case .float8: return "DOUBLE PRECISION" - case .money: return "MONEY" - case .boolArray: return "BOOLEAN[]" - case .byteaArray: return "BYTEA[]" - case .charArray: return "CHAR[]" - case .nameArray: return "NAME[]" - case .int2Array: return "SMALLINT[]" - case .int4Array: return "INTEGER[]" - case .textArray: return "TEXT[]" - case .varcharArray: return "VARCHAR[]" - case .int8Array: return "BIGINT[]" - case .pointArray: return "POINT[]" - case .float4Array: return "REAL[]" - case .float8Array: return "DOUBLE PRECISION[]" - case .aclitemArray: return "ACLITEM[]" - case .bpchar: return "BPCHAR" - case .varchar: return "VARCHAR" - case .date: return "DATE" - case .time: return "TIME" - case .timestamp: return "TIMESTAMP" - case .timestamptz: return "TIMESTAMPTZ" - case .timestampArray: return "TIMESTAMP[]" - case .numeric: return "NUMERIC" - case .void: return "VOID" - case .uuid: return "UUID" - case .uuidArray: return "UUID[]" - case .jsonb: return "JSONB" - case .jsonbArray: return "JSONB[]" - default: return nil - } - } - - /// See `CustomStringConvertible`. - var description: String { - return self.knownSQLName ?? "UNKNOWN \(self.rawValue)" - } -} - diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index 0cadc9ee..bc642e6d 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -89,13 +89,13 @@ struct PSQLCastingError: Error { let line: Int let targetType: PSQLDecodable.Type - let postgresType: PSQLDataType + let postgresType: PostgresDataType let postgresData: ByteBuffer? let description: String let underlying: Error? - static func missingData(targetType: PSQLDecodable.Type, type: PSQLDataType, context: PSQLDecodingContext) -> Self { + static func missingData(targetType: PSQLDecodable.Type, type: PostgresDataType, context: PSQLDecodingContext) -> Self { PSQLCastingError( columnName: context.columnName, columnIndex: context.columnIndex, @@ -113,7 +113,7 @@ struct PSQLCastingError: Error { } static func failure(targetType: PSQLDecodable.Type, - type: PSQLDataType, + type: PostgresDataType, postgresData: ByteBuffer, description: String? = nil, underlying: Error? = nil, diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index bb540a8e..d28a9f0f 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -14,8 +14,8 @@ struct PostgresJSONEncoderWrapper: PSQLJSONEncoder { } extension PostgresData: PSQLEncodable { - var psqlType: PSQLDataType { - PSQLDataType(Int32(self.type.rawValue)) + var psqlType: PostgresDataType { + self.type } var psqlFormat: PostgresFormat { @@ -39,8 +39,8 @@ extension PostgresData: PSQLEncodable { } extension PostgresData: PSQLDecodable { - static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> PostgresData { - let myBuffer = byteBuffer.readSlice(length: byteBuffer.readableBytes)! + static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + let myBuffer = buffer.readSlice(length: buffer.readableBytes)! return PostgresData(type: PostgresDataType(UInt32(type.rawValue)), typeModifier: nil, formatCode: .binary, value: myBuffer) } diff --git a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift index 57106393..b6041f6c 100644 --- a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift @@ -38,7 +38,7 @@ class JSON_PSQLCodableTests: XCTestCase { } func testDecodeFromJSONAsText() { - let combinations : [(PostgresFormat, PSQLDataType)] = [ + let combinations : [(PostgresFormat, PostgresDataType)] = [ (.text, .json), (.text, .jsonb), ] var buffer = ByteBuffer() diff --git a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift index 304bb7d6..f9d5b03d 100644 --- a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift @@ -19,7 +19,7 @@ class String_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeString(expected) - let dataTypes: [PSQLDataType] = [ + let dataTypes: [PostgresDataType] = [ .text, .varchar, .name ] @@ -33,7 +33,7 @@ class String_PSQLCodableTests: XCTestCase { func testDecodeFailureFromInvalidType() { let buffer = ByteBuffer() - let dataTypes: [PSQLDataType] = [.bool, .float4Array, .float8Array, .bpchar] + let dataTypes: [PostgresDataType] = [.bool, .float4Array, .float8Array, .bpchar] for dataType in dataTypes { var loopBuffer = buffer @@ -48,7 +48,7 @@ class String_PSQLCodableTests: XCTestCase { } func testDecodeFailureFromNoData() { - let dataTypes: [PSQLDataType] = [.text, .varchar, .name] + let dataTypes: [PostgresDataType] = [.text, .varchar, .name] for dataType in dataTypes { let data = PSQLData(bytes: nil, dataType: dataType, format: .binary) diff --git a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift index 3abf035b..9c639d98 100644 --- a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift @@ -40,7 +40,7 @@ class UUID_PSQLCodableTests: XCTestCase { } func testDecodeFromString() { - let options: [(PostgresFormat, PSQLDataType)] = [ + let options: [(PostgresFormat, PostgresDataType)] = [ (.binary, .text), (.binary, .varchar), (.text, .uuid), @@ -98,7 +98,7 @@ class UUID_PSQLCodableTests: XCTestCase { // this makes only 15 bytes readable. this should lead to an error buffer.moveReaderIndex(forwardBy: 1) - let dataTypes: [PSQLDataType] = [.varchar, .text] + let dataTypes: [PostgresDataType] = [.varchar, .text] for dataType in dataTypes { var loopBuffer = buffer @@ -117,7 +117,7 @@ class UUID_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeString(uuid.uuidString) - let dataTypes: [PSQLDataType] = [.bool, .int8, .int2, .int4Array] + let dataTypes: [PostgresDataType] = [.bool, .int8, .int2, .int4Array] for dataType in dataTypes { let data = PSQLData(bytes: buffer, dataType: dataType, format: .binary) diff --git a/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift b/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift index ebc80a8e..8bbdae4c 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift @@ -31,7 +31,7 @@ class ParameterDescriptionTests: XCTestCase { } func testDecodeWithNegativeCount() { - let dataTypes: [PSQLDataType] = [.bool, .varchar, .uuid, .json, .jsonbArray] + let dataTypes: [PostgresDataType] = [.bool, .varchar, .uuid, .json, .jsonbArray] var buffer = ByteBuffer() buffer.writeBackendMessage(id: .parameterDescription) { buffer in buffer.writeInteger(Int16(-4)) @@ -49,7 +49,7 @@ class ParameterDescriptionTests: XCTestCase { } func testDecodeColumnCountDoesntMatchMessageLength() { - let dataTypes: [PSQLDataType] = [.bool, .varchar, .uuid, .json, .jsonbArray] + let dataTypes: [PostgresDataType] = [.bool, .varchar, .uuid, .json, .jsonbArray] var buffer = ByteBuffer() buffer.writeBackendMessage(id: .parameterDescription) { buffer in // means three columns comming, but 5 are in the buffer actually. diff --git a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift index c147b749..edf3f48d 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift @@ -27,14 +27,14 @@ class ParseTests: XCTestCase { XCTAssertEqual(byteBuffer.readNullTerminatedString(), parse.preparedStatementName) XCTAssertEqual(byteBuffer.readNullTerminatedString(), parse.query) XCTAssertEqual(byteBuffer.readInteger(as: Int16.self), Int16(parse.parameters.count)) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), PSQLDataType.bool.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), PSQLDataType.int8.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), PSQLDataType.bytea.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), PSQLDataType.varchar.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), PSQLDataType.text.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), PSQLDataType.uuid.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), PSQLDataType.json.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), PSQLDataType.jsonbArray.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.bool.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.int8.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.bytea.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.varchar.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.text.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.uuid.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.json.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.jsonbArray.rawValue) } } diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift index fce52d13..abbfce14 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -287,7 +287,7 @@ class PSQLRowStreamTests: XCTestCase { XCTAssertEqual(stream.commandTag, "SELECT 6") } - func makeColumnDescription(name: String, dataType: PSQLDataType, format: PostgresFormat) -> RowDescription.Column { + func makeColumnDescription(name: String, dataType: PostgresDataType, format: PostgresFormat) -> RowDescription.Column { RowDescription.Column( name: "test", tableOID: 123, From 7660f79510132fc1b91f21916847eda411b23ca2 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 17 Feb 2022 21:29:21 +0100 Subject: [PATCH 049/292] Remove PSQLJSONEncoder (#215) --- .../Connection/PostgresConnection+Connect.swift | 2 +- Sources/PostgresNIO/New/Messages/Bind.swift | 2 +- Sources/PostgresNIO/New/PSQL+JSON.swift | 10 ---------- Sources/PostgresNIO/New/PSQLCodable.swift | 2 +- Sources/PostgresNIO/New/PSQLConnection.swift | 4 ++-- .../New/PSQLFrontendMessageEncoder.swift | 4 ++-- Sources/PostgresNIO/Postgres+PSQLCompat.swift | 13 ------------- .../PostgresNIO/Utilities/PostgresJSONEncoder.swift | 10 ++++++++++ .../New/Data/JSON+PSQLCodableTests.swift | 6 +++++- .../New/Extensions/PSQLCoding+TestUtils.swift | 2 +- 10 files changed, 23 insertions(+), 32 deletions(-) delete mode 100644 Sources/PostgresNIO/New/PSQL+JSON.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift index 518e9234..388cdbc4 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift @@ -12,7 +12,7 @@ extension PostgresConnection { ) -> EventLoopFuture { let coders = PSQLConnection.Configuration.Coders( - jsonEncoder: PostgresJSONEncoderWrapper(_defaultJSONEncoder) + jsonEncoder: _defaultJSONEncoder ) let configuration = PSQLConnection.Configuration( diff --git a/Sources/PostgresNIO/New/Messages/Bind.swift b/Sources/PostgresNIO/New/Messages/Bind.swift index 500a13b9..eea976c9 100644 --- a/Sources/PostgresNIO/New/Messages/Bind.swift +++ b/Sources/PostgresNIO/New/Messages/Bind.swift @@ -12,7 +12,7 @@ extension PSQLFrontendMessage { /// The number of parameter values that follow (possibly zero). This must match the number of parameters needed by the query. var parameters: [PSQLEncodable] - func encode(into buffer: inout ByteBuffer, using jsonEncoder: PSQLJSONEncoder) throws { + func encode(into buffer: inout ByteBuffer, using jsonEncoder: PostgresJSONEncoder) throws { buffer.writeNullTerminatedString(self.portalName) buffer.writeNullTerminatedString(self.preparedStatementName) diff --git a/Sources/PostgresNIO/New/PSQL+JSON.swift b/Sources/PostgresNIO/New/PSQL+JSON.swift deleted file mode 100644 index 4183d204..00000000 --- a/Sources/PostgresNIO/New/PSQL+JSON.swift +++ /dev/null @@ -1,10 +0,0 @@ -import NIOCore -import NIOFoundationCompat -import class Foundation.JSONEncoder -import class Foundation.JSONDecoder - -protocol PSQLJSONEncoder { - func encode(_ value: T, into buffer: inout ByteBuffer) throws -} - -extension JSONEncoder: PSQLJSONEncoder {} diff --git a/Sources/PostgresNIO/New/PSQLCodable.swift b/Sources/PostgresNIO/New/PSQLCodable.swift index 9b84bca0..fa45ea2a 100644 --- a/Sources/PostgresNIO/New/PSQLCodable.swift +++ b/Sources/PostgresNIO/New/PSQLCodable.swift @@ -55,7 +55,7 @@ extension PSQLEncodable { } struct PSQLEncodingContext { - let jsonEncoder: PSQLJSONEncoder + let jsonEncoder: PostgresJSONEncoder } struct PSQLDecodingContext { diff --git a/Sources/PostgresNIO/New/PSQLConnection.swift b/Sources/PostgresNIO/New/PSQLConnection.swift index 4f5d3f64..40b42b11 100644 --- a/Sources/PostgresNIO/New/PSQLConnection.swift +++ b/Sources/PostgresNIO/New/PSQLConnection.swift @@ -13,9 +13,9 @@ final class PSQLConnection { struct Configuration { struct Coders { - var jsonEncoder: PSQLJSONEncoder + var jsonEncoder: PostgresJSONEncoder - init(jsonEncoder: PSQLJSONEncoder) { + init(jsonEncoder: PostgresJSONEncoder) { self.jsonEncoder = jsonEncoder } diff --git a/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift index 227cd233..ea016970 100644 --- a/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift @@ -2,9 +2,9 @@ struct PSQLFrontendMessageEncoder: MessageToByteEncoder { typealias OutboundIn = PSQLFrontendMessage - let jsonEncoder: PSQLJSONEncoder + let jsonEncoder: PostgresJSONEncoder - init(jsonEncoder: PSQLJSONEncoder) { + init(jsonEncoder: PostgresJSONEncoder) { self.jsonEncoder = jsonEncoder } diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index d28a9f0f..c0f7cef8 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -1,18 +1,5 @@ import NIOCore -struct PostgresJSONEncoderWrapper: PSQLJSONEncoder { - let downstream: PostgresJSONEncoder - - init(_ downstream: PostgresJSONEncoder) { - self.downstream = downstream - } - - func encode(_ value: T, into buffer: inout ByteBuffer) throws where T : Encodable { - let data = try self.downstream.encode(value) - buffer.writeData(data) - } -} - extension PostgresData: PSQLEncodable { var psqlType: PostgresDataType { self.type diff --git a/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift b/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift index 9730a061..3cabcf1d 100644 --- a/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift +++ b/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift @@ -1,10 +1,20 @@ import Foundation +import NIOFoundationCompat /// A protocol that mimicks the Foundation `JSONEncoder.encode(_:)` function. /// Conform a non-Foundation JSON encoder to this protocol if you want PostgresNIO to be /// able to use it when encoding JSON & JSONB values (see `PostgresNIO._defaultJSONEncoder`) public protocol PostgresJSONEncoder { func encode(_ value: T) throws -> Data where T : Encodable + + func encode(_ value: T, into buffer: inout ByteBuffer) throws +} + +extension PostgresJSONEncoder { + public func encode(_ value: T, into buffer: inout ByteBuffer) throws { + let data = try self.encode(value) + buffer.writeData(data) + } } extension JSONEncoder: PostgresJSONEncoder {} diff --git a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift index b6041f6c..c9180016 100644 --- a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift @@ -73,12 +73,16 @@ class JSON_PSQLCodableTests: XCTestCase { } func testCustomEncoderIsUsed() { - class TestEncoder: PSQLJSONEncoder { + class TestEncoder: PostgresJSONEncoder { var encodeHits = 0 func encode(_ value: T, into buffer: inout ByteBuffer) throws where T : Encodable { self.encodeHits += 1 } + + func encode(_ value: T) throws -> Data where T : Encodable { + preconditionFailure() + } } let hello = Hello(name: "world") diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift index 49d68057..602306d8 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift @@ -14,7 +14,7 @@ extension PSQLDecodingContext { } extension PSQLEncodingContext { - static func forTests(jsonEncoder: PSQLJSONEncoder = JSONEncoder()) -> Self { + static func forTests(jsonEncoder: PostgresJSONEncoder = JSONEncoder()) -> Self { Self(jsonEncoder: jsonEncoder) } } From e61d43c8d32065000ae5620d81bb69c1c9f97a79 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 17 Feb 2022 22:12:30 +0100 Subject: [PATCH 050/292] PSQLDecodingError improvements (#211) --- .../New/Data/Array+PSQLCodable.swift | 10 +- .../New/Data/Bool+PSQLCodable.swift | 10 +- .../New/Data/Date+PSQLCodable.swift | 6 +- .../New/Data/Decimal+PSQLCodable.swift | 6 +- .../New/Data/Float+PSQLCodable.swift | 16 ++-- .../New/Data/Int+PSQLCodable.swift | 38 ++++---- .../New/Data/JSON+PSQLCodable.swift | 4 +- .../New/Data/Optional+PSQLCodable.swift | 33 +++++-- .../Data/RawRepresentable+PSQLCodable.swift | 2 +- .../New/Data/String+PSQLCodable.swift | 4 +- .../New/Data/UUID+PSQLCodable.swift | 8 +- Sources/PostgresNIO/New/PSQLCodable.swift | 23 ++++- Sources/PostgresNIO/New/PSQLData.swift | 25 ----- Sources/PostgresNIO/New/PSQLError.swift | 96 +++++++++---------- Sources/PostgresNIO/New/PSQLRow.swift | 7 +- .../New/Data/Array+PSQLCodableTests.swift | 45 ++++----- .../New/Data/Bool+PSQLCodableTests.swift | 39 ++++---- .../New/Data/Bytes+PSQLCodableTests.swift | 6 +- .../New/Data/Date+PSQLCodableTests.swift | 45 ++++----- .../New/Data/Decimal+PSQLCodableTests.swift | 10 +- .../New/Data/Float+PSQLCodableTests.swift | 76 +++++++-------- .../New/Data/JSON+PSQLCodableTests.swift | 28 +++--- .../New/Data/Optional+PSQLCodableTests.swift | 36 +++---- .../RawRepresentable+PSQLCodableTests.swift | 23 ++--- .../New/Data/String+PSQLCodableTests.swift | 31 +----- .../New/Data/UUID+PSQLCodableTests.swift | 25 ++--- .../PostgresNIOTests/New/PSQLDataTests.swift | 18 ---- .../New/PostgresErrorTests.swift | 28 ++++++ 28 files changed, 320 insertions(+), 378 deletions(-) delete mode 100644 Tests/PostgresNIOTests/New/PSQLDataTests.swift create mode 100644 Tests/PostgresNIOTests/New/PostgresErrorTests.swift diff --git a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift index bad901dc..ba89bbb8 100644 --- a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift @@ -104,13 +104,13 @@ extension Array: PSQLDecodable where Element: PSQLArrayElement { static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Array { guard case .binary = format else { // currently we only support decoding arrays in binary format. - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } guard let (isNotEmpty, b, element) = buffer.readMultipleIntegers(endianness: .big, as: (Int32, Int32, UInt32).self), 0 <= isNotEmpty, isNotEmpty <= 1, b == 0 else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } let elementType = PostgresDataType(element) @@ -123,7 +123,7 @@ extension Array: PSQLDecodable where Element: PSQLArrayElement { expectedArrayCount > 0, dimensions == 1 else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } var result = Array() @@ -131,11 +131,11 @@ extension Array: PSQLDecodable where Element: PSQLArrayElement { for _ in 0 ..< expectedArrayCount { guard let elementLength = buffer.readInteger(as: Int32.self) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } guard var elementBuffer = buffer.readSlice(length: numericCast(elementLength)) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } let element = try Element.decode(from: &elementBuffer, type: elementType, format: format, context: context) diff --git a/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift index 4bd4bb33..3d7a6776 100644 --- a/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift @@ -11,13 +11,13 @@ extension Bool: PSQLCodable { static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Bool { guard type == .bool else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.typeMismatch } switch format { case .binary: guard buffer.readableBytes == 1 else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } switch buffer.readInteger(as: UInt8.self) { @@ -26,11 +26,11 @@ extension Bool: PSQLCodable { case .some(1): return true default: - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } case .text: guard buffer.readableBytes == 1 else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } switch buffer.readInteger(as: UInt8.self) { @@ -39,7 +39,7 @@ extension Bool: PSQLCodable { case .some(UInt8(ascii: "t")): return true default: - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } } } diff --git a/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift index 868a0929..71201853 100644 --- a/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift @@ -14,18 +14,18 @@ extension Date: PSQLCodable { switch type { case .timestamp, .timestamptz: guard buffer.readableBytes == 8, let microseconds = buffer.readInteger(as: Int64.self) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } let seconds = Double(microseconds) / Double(_microsecondsPerSecond) return Date(timeInterval: seconds, since: _psqlDateStart) case .date: guard buffer.readableBytes == 4, let days = buffer.readInteger(as: Int32.self) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } let seconds = Int64(days) * _secondsInDay return Date(timeInterval: Double(seconds), since: _psqlDateStart) default: - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.typeMismatch } } diff --git a/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift index 990b9ebf..0a683e37 100644 --- a/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift @@ -14,16 +14,16 @@ extension Decimal: PSQLCodable { switch (format, type) { case (.binary, .numeric): guard let numeric = PostgresNumeric(buffer: &buffer) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return numeric.decimal case (.text, .numeric): guard let string = buffer.readString(length: buffer.readableBytes), let value = Decimal(string: string) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return value default: - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.typeMismatch } } diff --git a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift index 738160eb..0aab376f 100644 --- a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift @@ -13,21 +13,21 @@ extension Float: PSQLCodable { switch (format, type) { case (.binary, .float4): guard buffer.readableBytes == 4, let float = buffer.psqlReadFloat() else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return float case (.binary, .float8): guard buffer.readableBytes == 8, let double = buffer.psqlReadDouble() else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return Float(double) case (.text, .float4), (.text, .float8): guard let string = buffer.readString(length: buffer.readableBytes), let value = Float(string) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return value default: - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.typeMismatch } } @@ -49,21 +49,21 @@ extension Double: PSQLCodable { switch (format, type) { case (.binary, .float4): guard buffer.readableBytes == 4, let float = buffer.psqlReadFloat() else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return Double(float) case (.binary, .float8): guard buffer.readableBytes == 8, let double = buffer.psqlReadDouble() else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return double case (.text, .float4), (.text, .float8): guard let string = buffer.readString(length: buffer.readableBytes), let value = Double(string) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return value default: - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.typeMismatch } } diff --git a/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift index 41c411c3..d63bb8eb 100644 --- a/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift @@ -13,12 +13,12 @@ extension UInt8: PSQLCodable { switch type { case .bpchar, .char: guard buffer.readableBytes == 1, let value = buffer.readInteger(as: UInt8.self) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return value default: - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.typeMismatch } } @@ -41,16 +41,16 @@ extension Int16: PSQLCodable { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return value case (.text, .int2): guard let string = buffer.readString(length: buffer.readableBytes), let value = Int16(string) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return value default: - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.typeMismatch } } @@ -72,21 +72,21 @@ extension Int32: PSQLCodable { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return Int32(value) case (.binary, .int4): guard buffer.readableBytes == 4, let value = buffer.readInteger(as: Int32.self) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return Int32(value) case (.text, .int2), (.text, .int4): guard let string = buffer.readString(length: buffer.readableBytes), let value = Int32(string) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return value default: - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.typeMismatch } } @@ -108,26 +108,26 @@ extension Int64: PSQLCodable { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return Int64(value) case (.binary, .int4): guard buffer.readableBytes == 4, let value = buffer.readInteger(as: Int32.self) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return Int64(value) case (.binary, .int8): guard buffer.readableBytes == 8, let value = buffer.readInteger(as: Int64.self) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return value case (.text, .int2), (.text, .int4), (.text, .int8): guard let string = buffer.readString(length: buffer.readableBytes), let value = Int64(string) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return value default: - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.typeMismatch } } @@ -156,26 +156,26 @@ extension Int: PSQLCodable { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return Int(value) case (.binary, .int4): guard buffer.readableBytes == 4, let value = buffer.readInteger(as: Int32.self) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return Int(value) case (.binary, .int8) where Int.bitWidth == 64: guard buffer.readableBytes == 8, let value = buffer.readInteger(as: Int.self) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return value case (.text, .int2), (.text, .int4), (.text, .int8): guard let string = buffer.readString(length: buffer.readableBytes), let value = Int(string) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return value default: - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.typeMismatch } } diff --git a/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift index 7dc9348d..1500ce84 100644 --- a/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift @@ -18,13 +18,13 @@ extension PSQLCodable where Self: Codable { switch (format, type) { case (.binary, .jsonb): guard JSONBVersionByte == buffer.readInteger(as: UInt8.self) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return try context.jsonDecoder.decode(Self.self, from: buffer) case (.binary, .json), (.text, .jsonb), (.text, .json): return try context.jsonDecoder.decode(Self.self, from: buffer) default: - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.typeMismatch } } diff --git a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift index 53aa0f3a..a01d5f15 100644 --- a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift @@ -1,12 +1,29 @@ import NIOCore -extension Optional: PSQLDecodable where Wrapped: PSQLDecodable { - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { - preconditionFailure("This code path should never be hit.") - // The code path for decoding an optional should be: - // -> PSQLData.decode(as: String?.self) - // -> PSQLData.decodeIfPresent(String.self) - // -> String.decode(from: type:) +extension Optional: PSQLDecodable where Wrapped: PSQLDecodable, Wrapped.DecodableType == Wrapped { + typealias DecodableType = Wrapped + + static func decode( + from byteBuffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PSQLDecodingContext + ) throws -> Optional { + preconditionFailure("This should not be called") + } + + static func decodeRaw( + from byteBuffer: inout ByteBuffer?, + type: PostgresDataType, + format: PostgresFormat, + context: PSQLDecodingContext + ) throws -> Self { + switch byteBuffer { + case .some(var buffer): + return try DecodableType.decode(from: &buffer, type: type, format: format, context: context) + case .none: + return nil + } } } @@ -43,6 +60,6 @@ extension Optional: PSQLEncodable where Wrapped: PSQLEncodable { } } -extension Optional: PSQLCodable where Wrapped: PSQLCodable { +extension Optional: PSQLCodable where Wrapped: PSQLCodable, Wrapped.DecodableType == Wrapped { } diff --git a/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift index 706f58d3..f8812da3 100644 --- a/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift @@ -12,7 +12,7 @@ extension PSQLCodable where Self: RawRepresentable, RawValue: PSQLCodable { static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { guard let rawValue = try? RawValue.decode(from: &buffer, type: type, format: format, context: context), let selfValue = Self.init(rawValue: rawValue) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return selfValue diff --git a/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift index ca59f0e2..d761fc48 100644 --- a/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift @@ -24,11 +24,11 @@ extension String: PSQLCodable { return buffer.readString(length: buffer.readableBytes)! case (_, .uuid): guard let uuid = try? UUID.decode(from: &buffer, type: .uuid, format: format, context: context) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return uuid.uuidString default: - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.typeMismatch } } } diff --git a/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift index f7e738c2..0fdd2990 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift @@ -26,7 +26,7 @@ extension UUID: PSQLCodable { switch (format, type) { case (.binary, .uuid): guard let uuid = buffer.readUUID() else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return uuid case (.binary, .varchar), @@ -35,15 +35,15 @@ extension UUID: PSQLCodable { (.text, .text), (.text, .varchar): guard buffer.readableBytes == 36 else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } guard let uuid = buffer.readString(length: 36).flatMap({ UUID(uuidString: $0) }) else { - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.failure } return uuid default: - throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) + throw PostgresCastingError.Code.typeMismatch } } } diff --git a/Sources/PostgresNIO/New/PSQLCodable.swift b/Sources/PostgresNIO/New/PSQLCodable.swift index fa45ea2a..fbf3fbbb 100644 --- a/Sources/PostgresNIO/New/PSQLCodable.swift +++ b/Sources/PostgresNIO/New/PSQLCodable.swift @@ -20,6 +20,7 @@ protocol PSQLEncodable { /// A type that can decode itself from a postgres wire binary representation. protocol PSQLDecodable { + associatedtype DecodableType: PSQLDecodable = Self /// Decode an entity from the `byteBuffer` in postgres wire format /// @@ -32,7 +33,27 @@ protocol PSQLDecodable { /// - context: A `PSQLDecodingContext` providing context for decoding. This includes a `JSONDecoder` /// to use when decoding json and metadata to create better errors. /// - Returns: A decoded object - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self + static func decode(from byteBuffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self + + /// Decode an entity from the `byteBuffer` in postgres wire format. + /// This method has a default implementation and may be overriden + /// only for special cases, like `Optional`s. + static func decodeRaw(from byteBuffer: inout ByteBuffer?, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self +} + +extension PSQLDecodable { + @inlinable + public static func decodeRaw( + from byteBuffer: inout ByteBuffer?, + type: PostgresDataType, + format: PostgresFormat, + context: PSQLDecodingContext + ) throws -> Self { + guard var buffer = byteBuffer else { + throw PostgresCastingError.Code.missingData + } + return try self.decode(from: &buffer, type: type, format: format, context: context) + } } /// A type that can be encoded into and decoded from a postgres binary format diff --git a/Sources/PostgresNIO/New/PSQLData.swift b/Sources/PostgresNIO/New/PSQLData.swift index 9131834e..d490c78c 100644 --- a/Sources/PostgresNIO/New/PSQLData.swift +++ b/Sources/PostgresNIO/New/PSQLData.swift @@ -12,29 +12,4 @@ struct PSQLData: Equatable { self.dataType = dataType self.format = format } - - @inlinable - func decode(as: Optional.Type, context: PSQLDecodingContext) throws -> T? { - try self.decodeIfPresent(as: T.self, context: context) - } - - @inlinable - func decode(as type: T.Type, context: PSQLDecodingContext) throws -> T { - switch self.bytes { - case .none: - throw PSQLCastingError.missingData(targetType: type, type: self.dataType, context: context) - case .some(var buffer): - return try T.decode(from: &buffer, type: self.dataType, format: self.format, context: context) - } - } - - @inlinable - func decodeIfPresent(as: T.Type, context: PSQLDecodingContext) throws -> T? { - switch self.bytes { - case .none: - return nil - case .some(var buffer): - return try T.decode(from: &buffer, type: self.dataType, format: self.format, context: context) - } - } } diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index bc642e6d..42dd221e 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -18,7 +18,7 @@ struct PSQLError: Error { case connectionError(underlying: Error) case uncleanShutdown - case casting(PSQLCastingError) + case casting(PostgresCastingError) } internal var base: Base @@ -80,59 +80,59 @@ struct PSQLError: Error { } } -struct PSQLCastingError: Error { +struct PostgresCastingError: Error, Equatable { + struct Code: Hashable, Error { + enum Base { + case missingData + case typeMismatch + case failure + } + + var base: Base + + init(_ base: Base) { + self.base = base + } + + static let missingData = Self.init(.missingData) + static let typeMismatch = Self.init(.typeMismatch) + static let failure = Self.init(.failure) + } + + let code: Code let columnName: String let columnIndex: Int - - let file: String - let line: Int - - let targetType: PSQLDecodable.Type + let targetType: Any.Type let postgresType: PostgresDataType let postgresData: ByteBuffer? - let description: String - let underlying: Error? - - static func missingData(targetType: PSQLDecodable.Type, type: PostgresDataType, context: PSQLDecodingContext) -> Self { - PSQLCastingError( - columnName: context.columnName, - columnIndex: context.columnIndex, - file: context.file, - line: context.line, - targetType: targetType, - postgresType: type, - postgresData: nil, - description: """ - Failed to cast Postgres data type \(type.description) to Swift type \(targetType) \ - because of missing data in \(context.file) line \(context.line). - """, - underlying: nil - ) - } - - static func failure(targetType: PSQLDecodable.Type, - type: PostgresDataType, - postgresData: ByteBuffer, - description: String? = nil, - underlying: Error? = nil, - context: PSQLDecodingContext) -> Self - { - PSQLCastingError( - columnName: context.columnName, - columnIndex: context.columnIndex, - file: context.file, - line: context.line, - targetType: targetType, - postgresType: type, - postgresData: postgresData, - description: description ?? """ - Failed to cast Postgres data type \(type.description) to Swift type \(targetType) \ - in \(context.file) line \(context.line)." - """, - underlying: underlying - ) + var description: String { + switch self.code.base { + case .missingData: + return """ + Failed to cast Postgres data type \(self.postgresType.description) to Swift type \(self.targetType) \ + because of missing data. + """ + + case .typeMismatch: + preconditionFailure() + + case .failure: + return """ + Failed to cast Postgres data type \(self.postgresType.description) to Swift type \(self.targetType). + """ + } + + } + + static func ==(lhs: PostgresCastingError, rhs: PostgresCastingError) -> Bool { + return lhs.code == rhs.code + && lhs.columnName == rhs.columnName + && lhs.columnIndex == rhs.columnIndex + && lhs.targetType == rhs.targetType + && lhs.postgresType == rhs.postgresType + && lhs.postgresData == rhs.postgresData } } diff --git a/Sources/PostgresNIO/New/PSQLRow.swift b/Sources/PostgresNIO/New/PSQLRow.swift index f76f9eef..dbd57c48 100644 --- a/Sources/PostgresNIO/New/PSQLRow.swift +++ b/Sources/PostgresNIO/New/PSQLRow.swift @@ -54,10 +54,9 @@ extension PSQLRow { columnIndex: index, file: file, line: line) - - guard var cellSlice = self.data[column: index] else { - throw PSQLCastingError.missingData(targetType: T.self, type: column.dataType, context: context) - } + + // Safe to force unwrap here, as we have ensured above that the row has enough columns + var cellSlice = self.data[column: index]! return try T.decode(from: &cellSlice, type: column.dataType, format: column.format, context: context) } diff --git a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift index 1079205e..a155399f 100644 --- a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift @@ -62,10 +62,9 @@ class Array_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() XCTAssertNoThrow(try values.encode(into: &buffer, context: .forTests())) - let data = PSQLData(bytes: buffer, dataType: .textArray, format: .binary) var result: [String]? - XCTAssertNoThrow(result = try data.decode(as: [String].self, context: .forTests())) + XCTAssertNoThrow(result = try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .forTests())) XCTAssertEqual(values, result) } @@ -74,10 +73,9 @@ class Array_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() XCTAssertNoThrow(try values.encode(into: &buffer, context: .forTests())) - let data = PSQLData(bytes: buffer, dataType: .textArray, format: .binary) var result: [String]? - XCTAssertNoThrow(result = try data.decode(as: [String].self, context: .forTests())) + XCTAssertNoThrow(result = try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .forTests())) XCTAssertEqual(values, result) } @@ -86,10 +84,9 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int32(2)) // invalid value buffer.writeInteger(Int32(0)) buffer.writeInteger(String.psqlArrayElementType.rawValue) - let data = PSQLData(bytes: buffer, dataType: .textArray, format: .binary) - XCTAssertThrowsError(try data.decode(as: [String].self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -98,10 +95,9 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int32(0)) // is empty buffer.writeInteger(Int32(1)) // invalid value, must always be 0 buffer.writeInteger(String.psqlArrayElementType.rawValue) - let data = PSQLData(bytes: buffer, dataType: .textArray, format: .binary) - XCTAssertThrowsError(try data.decode(as: [String].self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -109,10 +105,9 @@ class Array_PSQLCodableTests: XCTestCase { let value: Int64 = 1 << 32 var buffer = ByteBuffer() value.encode(into: &buffer, context: .forTests()) - let data = PSQLData(bytes: buffer, dataType: .textArray, format: .binary) - - XCTAssertThrowsError(try data.decode(as: [String].self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + + XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -123,10 +118,9 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(String.psqlArrayElementType.rawValue) buffer.writeInteger(Int32(-123)) // expected element count buffer.writeInteger(Int32(1)) // dimensions... must be one - let data = PSQLData(bytes: buffer, dataType: .textArray, format: .binary) - - XCTAssertThrowsError(try data.decode(as: [String].self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + + XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -137,10 +131,9 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(String.psqlArrayElementType.rawValue) buffer.writeInteger(Int32(1)) // expected element count buffer.writeInteger(Int32(2)) // dimensions... must be one - let data = PSQLData(bytes: buffer, dataType: .textArray, format: .binary) - XCTAssertThrowsError(try data.decode(as: [String].self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -152,10 +145,9 @@ class Array_PSQLCodableTests: XCTestCase { unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // expected element count unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // dimensions unexpectedEndInElementLengthBuffer.writeInteger(Int16(1)) // length of element, must be Int32 - let data = PSQLData(bytes: unexpectedEndInElementLengthBuffer, dataType: .textArray, format: .binary) - XCTAssertThrowsError(try data.decode(as: [String].self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + XCTAssertThrowsError(try [String].decode(from: &unexpectedEndInElementLengthBuffer, type: .textArray, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } var unexpectedEndInElementBuffer = ByteBuffer() @@ -166,10 +158,9 @@ class Array_PSQLCodableTests: XCTestCase { unexpectedEndInElementBuffer.writeInteger(Int32(1)) // dimensions unexpectedEndInElementBuffer.writeInteger(Int32(12)) // length of element, must be Int32 unexpectedEndInElementBuffer.writeString("Hello World") // only 11 bytes, 12 needed! - let unexpectedEndInElementData = PSQLData(bytes: unexpectedEndInElementBuffer, dataType: .textArray, format: .binary) - XCTAssertThrowsError(try unexpectedEndInElementData.decode(as: [String].self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + XCTAssertThrowsError(try [String].decode(from: &unexpectedEndInElementBuffer, type: .textArray, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } } diff --git a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift index f7d40834..773a35b8 100644 --- a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift @@ -15,10 +15,9 @@ class Bool_PSQLCodableTests: XCTestCase { XCTAssertEqual(value.psqlFormat, .binary) XCTAssertEqual(buffer.readableBytes, 1) XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 1) - let data = PSQLData(bytes: buffer, dataType: .bool, format: .binary) var result: Bool? - XCTAssertNoThrow(result = try data.decode(as: Bool.self, context: .forTests())) + XCTAssertNoThrow(result = try Bool.decode(from: &buffer, type: .bool, format: .binary, context: .forTests())) XCTAssertEqual(value, result) } @@ -31,30 +30,27 @@ class Bool_PSQLCodableTests: XCTestCase { XCTAssertEqual(value.psqlFormat, .binary) XCTAssertEqual(buffer.readableBytes, 1) XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 0) - let data = PSQLData(bytes: buffer, dataType: .bool, format: .binary) - + var result: Bool? - XCTAssertNoThrow(result = try data.decode(as: Bool.self, context: .forTests())) + XCTAssertNoThrow(result = try Bool.decode(from: &buffer, type: .bool, format: .binary, context: .forTests())) XCTAssertEqual(value, result) } func testBinaryDecodeBoolInvalidLength() { var buffer = ByteBuffer() buffer.writeInteger(Int64(1)) - let data = PSQLData(bytes: buffer, dataType: .bool, format: .binary) - - XCTAssertThrowsError(try data.decode(as: Bool.self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + + XCTAssertThrowsError(try Bool.decode(from: &buffer, type: .bool, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } func testBinaryDecodeBoolInvalidValue() { var buffer = ByteBuffer() buffer.writeInteger(UInt8(13)) - let data = PSQLData(bytes: buffer, dataType: .bool, format: .binary) - - XCTAssertThrowsError(try data.decode(as: Bool.self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + + XCTAssertThrowsError(try Bool.decode(from: &buffer, type: .bool, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -65,10 +61,9 @@ class Bool_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(UInt8(ascii: "t")) - let data = PSQLData(bytes: buffer, dataType: .bool, format: .text) - + var result: Bool? - XCTAssertNoThrow(result = try data.decode(as: Bool.self, context: .forTests())) + XCTAssertNoThrow(result = try Bool.decode(from: &buffer, type: .bool, format: .text, context: .forTests())) XCTAssertEqual(value, result) } @@ -77,20 +72,18 @@ class Bool_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(UInt8(ascii: "f")) - let data = PSQLData(bytes: buffer, dataType: .bool, format: .text) - + var result: Bool? - XCTAssertNoThrow(result = try data.decode(as: Bool.self, context: .forTests())) + XCTAssertNoThrow(result = try Bool.decode(from: &buffer, type: .bool, format: .text, context: .forTests())) XCTAssertEqual(value, result) } func testTextDecodeBoolInvalidValue() { var buffer = ByteBuffer() buffer.writeInteger(UInt8(13)) - let data = PSQLData(bytes: buffer, dataType: .bool, format: .text) - - XCTAssertThrowsError(try data.decode(as: Bool.self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + + XCTAssertThrowsError(try Bool.decode(from: &buffer, type: .bool, format: .text, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } } diff --git a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift index 7d58b660..a3ad33a7 100644 --- a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift @@ -10,10 +10,9 @@ class Bytes_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() data.encode(into: &buffer, context: .forTests()) XCTAssertEqual(data.psqlType, .bytea) - let psqlData = PSQLData(bytes: buffer, dataType: .bytea, format: .binary) var result: Data? - XCTAssertNoThrow(result = try psqlData.decode(as: Data.self, context: .forTests())) + XCTAssertNoThrow(result = try Data.decode(from: &buffer, type: .bytea, format: .binary, context: .forTests())) XCTAssertEqual(data, result) } @@ -23,10 +22,9 @@ class Bytes_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() bytes.encode(into: &buffer, context: .forTests()) XCTAssertEqual(bytes.psqlType, .bytea) - let psqlData = PSQLData(bytes: buffer, dataType: .bytea, format: .binary) var result: ByteBuffer? - XCTAssertNoThrow(result = try psqlData.decode(as: ByteBuffer.self, context: .forTests())) + XCTAssertNoThrow(result = try ByteBuffer.decode(from: &buffer, type: .bytea, format: .binary, context: .forTests())) XCTAssertEqual(bytes, result) } diff --git a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift index aae7ad8b..87eb46de 100644 --- a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift @@ -11,20 +11,18 @@ class Date_PSQLCodableTests: XCTestCase { value.encode(into: &buffer, context: .forTests()) XCTAssertEqual(value.psqlType, .timestamptz) XCTAssertEqual(buffer.readableBytes, 8) - let data = PSQLData(bytes: buffer, dataType: .timestamptz, format: .binary) - + var result: Date? - XCTAssertNoThrow(result = try data.decode(as: Date.self, context: .forTests())) + XCTAssertNoThrow(result = try Date.decode(from: &buffer, type: .timestamptz, format: .binary, context: .forTests())) XCTAssertEqual(value, result) } func testDecodeRandomDate() { var buffer = ByteBuffer() buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) - let data = PSQLData(bytes: buffer, dataType: .timestamptz, format: .binary) - + var result: Date? - XCTAssertNoThrow(result = try data.decode(as: Date.self, context: .forTests())) + XCTAssertNoThrow(result = try Date.decode(from: &buffer, type: .timestamptz, format: .binary, context: .forTests())) XCTAssertNotNil(result) } @@ -32,66 +30,59 @@ class Date_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) - let data = PSQLData(bytes: buffer, dataType: .timestamptz, format: .binary) - - XCTAssertThrowsError(try data.decode(as: Date.self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + + XCTAssertThrowsError(try Date.decode(from: &buffer, type: .timestamptz, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } func testDecodeDate() { var firstDateBuffer = ByteBuffer() firstDateBuffer.writeInteger(Int32.min) - let firstDateData = PSQLData(bytes: firstDateBuffer, dataType: .date, format: .binary) var firstDate: Date? - XCTAssertNoThrow(firstDate = try firstDateData.decode(as: Date.self, context: .forTests())) + XCTAssertNoThrow(firstDate = try Date.decode(from: &firstDateBuffer, type: .date, format: .binary, context: .forTests())) XCTAssertNotNil(firstDate) var lastDateBuffer = ByteBuffer() lastDateBuffer.writeInteger(Int32.max) - let lastDateData = PSQLData(bytes: lastDateBuffer, dataType: .date, format: .binary) - + var lastDate: Date? - XCTAssertNoThrow(lastDate = try lastDateData.decode(as: Date.self, context: .forTests())) + XCTAssertNoThrow(lastDate = try Date.decode(from: &lastDateBuffer, type: .date, format: .binary, context: .forTests())) XCTAssertNotNil(lastDate) } func testDecodeDateFromTimestamp() { var firstDateBuffer = ByteBuffer() firstDateBuffer.writeInteger(Int32.min) - let firstDateData = PSQLData(bytes: firstDateBuffer, dataType: .date, format: .binary) var firstDate: Date? - XCTAssertNoThrow(firstDate = try firstDateData.decode(as: Date.self, context: .forTests())) + XCTAssertNoThrow(firstDate = try Date.decode(from: &firstDateBuffer, type: .date, format: .binary, context: .forTests())) XCTAssertNotNil(firstDate) var lastDateBuffer = ByteBuffer() lastDateBuffer.writeInteger(Int32.max) - let lastDateData = PSQLData(bytes: lastDateBuffer, dataType: .date, format: .binary) var lastDate: Date? - XCTAssertNoThrow(lastDate = try lastDateData.decode(as: Date.self, context: .forTests())) + XCTAssertNoThrow(lastDate = try Date.decode(from: &lastDateBuffer, type: .date, format: .binary, context: .forTests())) XCTAssertNotNil(lastDate) } func testDecodeDateFailsWithToMuchData() { var buffer = ByteBuffer() buffer.writeInteger(Int64(0)) - let data = PSQLData(bytes: buffer, dataType: .date, format: .binary) - - XCTAssertThrowsError(try data.decode(as: Date.self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + + XCTAssertThrowsError(try Date.decode(from: &buffer, type: .date, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } func testDecodeDateFailsWithWrongDataType() { var buffer = ByteBuffer() buffer.writeInteger(Int64(0)) - let data = PSQLData(bytes: buffer, dataType: .int8, format: .binary) - - XCTAssertThrowsError(try data.decode(as: Date.self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + + XCTAssertThrowsError(try Date.decode(from: &buffer, type: .int8, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } } diff --git a/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift index afdcad20..8348c848 100644 --- a/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift @@ -11,10 +11,9 @@ class Decimal_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() value.encode(into: &buffer, context: .forTests()) XCTAssertEqual(value.psqlType, .numeric) - let data = PSQLData(bytes: buffer, dataType: .numeric, format: .binary) - + var result: Decimal? - XCTAssertNoThrow(result = try data.decode(as: Decimal.self, context: .forTests())) + XCTAssertNoThrow(result = try Decimal.decode(from: &buffer, type: .numeric, format: .binary, context: .forTests())) XCTAssertEqual(value, result) } } @@ -22,10 +21,9 @@ class Decimal_PSQLCodableTests: XCTestCase { func testDecodeFailureInvalidType() { var buffer = ByteBuffer() buffer.writeInteger(Int64(0)) - let data = PSQLData(bytes: buffer, dataType: .int8, format: .binary) - XCTAssertThrowsError(try data.decode(as: Decimal.self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + XCTAssertThrowsError(try Decimal.decode(from: &buffer, type: .int8, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } } diff --git a/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift index 33b8c0da..108b99ec 100644 --- a/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift @@ -12,10 +12,9 @@ class Float_PSQLCodableTests: XCTestCase { value.encode(into: &buffer, context: .forTests()) XCTAssertEqual(value.psqlType, .float8) XCTAssertEqual(buffer.readableBytes, 8) - let data = PSQLData(bytes: buffer, dataType: .float8, format: .binary) - + var result: Double? - XCTAssertNoThrow(result = try data.decode(as: Double.self, context: .forTests())) + XCTAssertNoThrow(result = try Double.decode(from: &buffer, type: .float8, format: .binary, context: .forTests())) XCTAssertEqual(value, result) } } @@ -28,10 +27,9 @@ class Float_PSQLCodableTests: XCTestCase { value.encode(into: &buffer, context: .forTests()) XCTAssertEqual(value.psqlType, .float4) XCTAssertEqual(buffer.readableBytes, 4) - let data = PSQLData(bytes: buffer, dataType: .float4, format: .binary) - + var result: Float? - XCTAssertNoThrow(result = try data.decode(as: Float.self, context: .forTests())) + XCTAssertNoThrow(result = try Float.decode(from: &buffer, type: .float4, format: .binary, context: .forTests())) XCTAssertEqual(value, result) } } @@ -43,10 +41,9 @@ class Float_PSQLCodableTests: XCTestCase { value.encode(into: &buffer, context: .forTests()) XCTAssertEqual(value.psqlType, .float8) XCTAssertEqual(buffer.readableBytes, 8) - let data = PSQLData(bytes: buffer, dataType: .float8, format: .binary) - + var result: Double? - XCTAssertNoThrow(result = try data.decode(as: Double.self, context: .forTests())) + XCTAssertNoThrow(result = try Double.decode(from: &buffer, type: .float8, format: .binary, context: .forTests())) XCTAssertEqual(result?.isNaN, true) } @@ -57,10 +54,9 @@ class Float_PSQLCodableTests: XCTestCase { value.encode(into: &buffer, context: .forTests()) XCTAssertEqual(value.psqlType, .float8) XCTAssertEqual(buffer.readableBytes, 8) - let data = PSQLData(bytes: buffer, dataType: .float8, format: .binary) - + var result: Double? - XCTAssertNoThrow(result = try data.decode(as: Double.self, context: .forTests())) + XCTAssertNoThrow(result = try Double.decode(from: &buffer, type: .float8, format: .binary, context: .forTests())) XCTAssertEqual(result?.isInfinite, true) } @@ -72,10 +68,9 @@ class Float_PSQLCodableTests: XCTestCase { value.encode(into: &buffer, context: .forTests()) XCTAssertEqual(value.psqlType, .float4) XCTAssertEqual(buffer.readableBytes, 4) - let data = PSQLData(bytes: buffer, dataType: .float4, format: .binary) - + var result: Double? - XCTAssertNoThrow(result = try data.decode(as: Double.self, context: .forTests())) + XCTAssertNoThrow(result = try Double.decode(from: &buffer, type: .float4, format: .binary, context: .forTests())) XCTAssertEqual(result, Double(value)) } } @@ -88,10 +83,9 @@ class Float_PSQLCodableTests: XCTestCase { value.encode(into: &buffer, context: .forTests()) XCTAssertEqual(value.psqlType, .float8) XCTAssertEqual(buffer.readableBytes, 8) - let data = PSQLData(bytes: buffer, dataType: .float8, format: .binary) - + var result: Float? - XCTAssertNoThrow(result = try data.decode(as: Float.self, context: .forTests())) + XCTAssertNoThrow(result = try Float.decode(from: &buffer, type: .float8, format: .binary, context: .forTests())) XCTAssertEqual(result, Float(value)) } } @@ -101,38 +95,40 @@ class Float_PSQLCodableTests: XCTestCase { eightByteBuffer.writeInteger(Int64(0)) var fourByteBuffer = ByteBuffer() fourByteBuffer.writeInteger(Int32(0)) - let toLongData = PSQLData(bytes: eightByteBuffer, dataType: .float4, format: .binary) - let toShortData = PSQLData(bytes: fourByteBuffer, dataType: .float8, format: .binary) - - XCTAssertThrowsError(try toLongData.decode(as: Double.self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + + var toLongBuffer1 = eightByteBuffer + XCTAssertThrowsError(try Double.decode(from: &toLongBuffer1, type: .float4, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } - - XCTAssertThrowsError(try toLongData.decode(as: Float.self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + + var toLongBuffer2 = eightByteBuffer + XCTAssertThrowsError(try Float.decode(from: &toLongBuffer2, type: .float4, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } - - XCTAssertThrowsError(try toShortData.decode(as: Double.self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + + var toShortBuffer1 = fourByteBuffer + XCTAssertThrowsError(try Double.decode(from: &toShortBuffer1, type: .float8, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } - - XCTAssertThrowsError(try toShortData.decode(as: Float.self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + + var toShortBuffer2 = fourByteBuffer + XCTAssertThrowsError(try Float.decode(from: &toShortBuffer2, type: .float8, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } func testDecodeFailureInvalidType() { var buffer = ByteBuffer() buffer.writeInteger(Int64(0)) - let data = PSQLData(bytes: buffer, dataType: .int8, format: .binary) - - XCTAssertThrowsError(try data.decode(as: Double.self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + + var copy1 = buffer + XCTAssertThrowsError(try Double.decode(from: ©1, type: .int8, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } - - XCTAssertThrowsError(try data.decode(as: Float.self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + + var copy2 = buffer + XCTAssertThrowsError(try Float.decode(from: ©2, type: .int8, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } } - } diff --git a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift index c9180016..40bf3f34 100644 --- a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift @@ -20,20 +20,18 @@ class JSON_PSQLCodableTests: XCTestCase { // verify jsonb prefix byte XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 1) - - let data = PSQLData(bytes: buffer, dataType: .jsonb, format: .binary) + var result: Hello? - XCTAssertNoThrow(result = try data.decode(as: Hello.self, context: .forTests())) + XCTAssertNoThrow(result = try Hello.decode(from: &buffer, type: .jsonb, format: .binary, context: .forTests())) XCTAssertEqual(result, hello) } func testDecodeFromJSON() { var buffer = ByteBuffer() buffer.writeString(#"{"hello":"world"}"#) - - let data = PSQLData(bytes: buffer, dataType: .json, format: .binary) + var result: Hello? - XCTAssertNoThrow(result = try data.decode(as: Hello.self, context: .forTests())) + XCTAssertNoThrow(result = try Hello.decode(from: &buffer, type: .json, format: .binary, context: .forTests())) XCTAssertEqual(result, Hello(name: "world")) } @@ -45,9 +43,9 @@ class JSON_PSQLCodableTests: XCTestCase { buffer.writeString(#"{"hello":"world"}"#) for (format, dataType) in combinations { - let data = PSQLData(bytes: buffer, dataType: dataType, format: format) + var loopBuffer = buffer var result: Hello? - XCTAssertNoThrow(result = try data.decode(as: Hello.self, context: .forTests())) + XCTAssertNoThrow(result = try Hello.decode(from: &loopBuffer, type: dataType, format: format, context: .forTests())) XCTAssertEqual(result, Hello(name: "world")) } } @@ -55,20 +53,18 @@ class JSON_PSQLCodableTests: XCTestCase { func testDecodeFromJSONBWithoutVersionPrefixByte() { var buffer = ByteBuffer() buffer.writeString(#"{"hello":"world"}"#) - - let data = PSQLData(bytes: buffer, dataType: .jsonb, format: .binary) - XCTAssertThrowsError(try data.decode(as: Hello.self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + + XCTAssertThrowsError(try Hello.decode(from: &buffer, type: .jsonb, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } func testDecodeFromJSONBWithWrongDataType() { var buffer = ByteBuffer() buffer.writeString(#"{"hello":"world"}"#) - - let data = PSQLData(bytes: buffer, dataType: .text, format: .binary) - XCTAssertThrowsError(try data.decode(as: Hello.self, context: .forTests())) { error in - XCTAssert(error is PSQLCastingError) + + XCTAssertThrowsError(try Hello.decode(from: &buffer, type: .text, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } } diff --git a/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift index ead0a1b4..62dbb9d7 100644 --- a/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift @@ -8,27 +8,28 @@ class Optional_PSQLCodableTests: XCTestCase { let value: String? = "Hello World" var buffer = ByteBuffer() - value?.encode(into: &buffer, context: .forTests()) + XCTAssertNoThrow(try value.encodeRaw(into: &buffer, context: .forTests())) XCTAssertEqual(value.psqlType, .text) - let data = PSQLData(bytes: buffer, dataType: .text, format: .binary) - + XCTAssertEqual(buffer.readInteger(as: Int32.self), 11) + var result: String? - XCTAssertNoThrow(result = try data.decode(as: String?.self, context: .forTests())) + var optBuffer: ByteBuffer? = buffer + XCTAssertNoThrow(result = try String?.decodeRaw(from: &optBuffer, type: .text, format: .binary, context: .forTests())) XCTAssertEqual(result, value) } func testRoundTripNoneString() { let value: Optional = .none - + var buffer = ByteBuffer() - value?.encode(into: &buffer, context: .forTests()) - XCTAssertEqual(buffer.readableBytes, 0) + XCTAssertNoThrow(try value.encodeRaw(into: &buffer, context: .forTests())) + XCTAssertEqual(buffer.readableBytes, 4) + XCTAssertEqual(buffer.getInteger(at: 0, as: Int32.self), -1) XCTAssertEqual(value.psqlType, .null) - - let data = PSQLData(bytes: nil, dataType: .text, format: .binary) - + var result: String? - XCTAssertNoThrow(result = try data.decode(as: String?.self, context: .forTests())) + var inBuffer: ByteBuffer? = nil + XCTAssertNoThrow(result = try String?.decodeRaw(from: &inBuffer, type: .text, format: .binary, context: .forTests())) XCTAssertEqual(result, value) } @@ -41,10 +42,10 @@ class Optional_PSQLCodableTests: XCTestCase { XCTAssertNoThrow(try encodable.encodeRaw(into: &buffer, context: .forTests())) XCTAssertEqual(buffer.readableBytes, 20) XCTAssertEqual(buffer.readInteger(as: Int32.self), 16) - let data = PSQLData(bytes: buffer, dataType: .uuid, format: .binary) - + var result: UUID? - XCTAssertNoThrow(result = try data.decode(as: UUID?.self, context: .forTests())) + var optBuffer: ByteBuffer? = buffer + XCTAssertNoThrow(result = try UUID?.decodeRaw(from: &optBuffer, type: .uuid, format: .binary, context: .forTests())) XCTAssertEqual(result, value) } @@ -57,11 +58,10 @@ class Optional_PSQLCodableTests: XCTestCase { XCTAssertNoThrow(try encodable.encodeRaw(into: &buffer, context: .forTests())) XCTAssertEqual(buffer.readableBytes, 4) XCTAssertEqual(buffer.readInteger(as: Int32.self), -1) - - let data = PSQLData(bytes: nil, dataType: .uuid, format: .binary) - + var result: UUID? - XCTAssertNoThrow(result = try data.decode(as: UUID?.self, context: .forTests())) + var inBuffer: ByteBuffer? = nil + XCTAssertNoThrow(result = try UUID?.decodeRaw(from: &inBuffer, type: .text, format: .binary, context: .forTests())) XCTAssertEqual(result, value) } } diff --git a/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift index cf233890..712d8843 100644 --- a/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift @@ -18,10 +18,9 @@ class RawRepresentable_PSQLCodableTests: XCTestCase { XCTAssertNoThrow(try value.encode(into: &buffer, context: .forTests())) XCTAssertEqual(value.psqlType, Int16.psqlArrayElementType) XCTAssertEqual(buffer.readableBytes, 2) - let data = PSQLData(bytes: buffer, dataType: Int16.psqlArrayElementType, format: .binary) - + var result: MyRawRepresentable? - XCTAssertNoThrow(result = try data.decode(as: MyRawRepresentable.self, context: .forTests())) + XCTAssertNoThrow(result = try MyRawRepresentable.decode(from: &buffer, type: Int16.psqlArrayElementType, format: .binary, context: .forTests())) XCTAssertEqual(value, result) } } @@ -29,24 +28,18 @@ class RawRepresentable_PSQLCodableTests: XCTestCase { func testDecodeInvalidRawTypeValue() { var buffer = ByteBuffer() buffer.writeInteger(Int16(4)) // out of bounds - let data = PSQLData(bytes: buffer, dataType: Int16.psqlArrayElementType, format: .binary) - - XCTAssertThrowsError(try data.decode(as: MyRawRepresentable.self, context: .forTests())) { error in - XCTAssertEqual((error as? PSQLCastingError)?.line, #line - 1) - XCTAssertEqual((error as? PSQLCastingError)?.file, #file) - XCTAssert((error as? PSQLCastingError)?.targetType == MyRawRepresentable.self) + + XCTAssertThrowsError(try MyRawRepresentable.decode(from: &buffer, type: Int16.psqlArrayElementType, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } func testDecodeInvalidUnderlyingTypeValue() { var buffer = ByteBuffer() buffer.writeInteger(Int32(1)) // out of bounds - let data = PSQLData(bytes: buffer, dataType: Int32.psqlArrayElementType, format: .binary) - - XCTAssertThrowsError(try data.decode(as: MyRawRepresentable.self, context: .forTests())) { error in - XCTAssertEqual((error as? PSQLCastingError)?.line, #line - 1) - XCTAssertEqual((error as? PSQLCastingError)?.file, #file) - XCTAssert((error as? PSQLCastingError)?.targetType == MyRawRepresentable.self) + + XCTAssertThrowsError(try MyRawRepresentable.decode(from: &buffer, type: Int32.psqlArrayElementType, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } diff --git a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift index f9d5b03d..12d9d9e2 100644 --- a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift @@ -37,27 +37,8 @@ class String_PSQLCodableTests: XCTestCase { for dataType in dataTypes { var loopBuffer = buffer - XCTAssertThrowsError(try String.decode(from: &loopBuffer, type: dataType, format: .binary, context: .forTests())) { error in - XCTAssertEqual((error as? PSQLCastingError)?.line, #line - 1) - XCTAssertEqual((error as? PSQLCastingError)?.file, #file) - - XCTAssertEqual((error as? PSQLCastingError)?.columnIndex, 0) - XCTAssertEqual((error as? PSQLCastingError)?.postgresData, loopBuffer) - } - } - } - - func testDecodeFailureFromNoData() { - let dataTypes: [PostgresDataType] = [.text, .varchar, .name] - - for dataType in dataTypes { - let data = PSQLData(bytes: nil, dataType: dataType, format: .binary) - XCTAssertThrowsError(try data.decode(as: String.self, context: .forTests())) { error in - XCTAssertEqual((error as? PSQLCastingError)?.line, #line - 1) - XCTAssertEqual((error as? PSQLCastingError)?.file, #file) - - XCTAssertEqual((error as? PSQLCastingError)?.columnIndex, 0) - XCTAssertEqual((error as? PSQLCastingError)?.postgresData, nil) + XCTAssertThrowsError(try String.decode(from: &loopBuffer, type: dataType, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } } } @@ -79,12 +60,8 @@ class String_PSQLCodableTests: XCTestCase { // this makes only 15 bytes readable. this should lead to an error buffer.moveReaderIndex(forwardBy: 1) - XCTAssertThrowsError(try String.decode(from: &buffer, type: .uuid, format: .binary, context: .forTests())) { error in - XCTAssertEqual((error as? PSQLCastingError)?.line, #line - 1) - XCTAssertEqual((error as? PSQLCastingError)?.file, #file) - - XCTAssertEqual((error as? PSQLCastingError)?.columnIndex, 0) - XCTAssertEqual((error as? PSQLCastingError)?.postgresData, buffer) + XCTAssertThrowsError(try String.decode(from: &buffer, type: .uuid, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } } diff --git a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift index 9c639d98..5add881a 100644 --- a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift @@ -83,11 +83,7 @@ class UUID_PSQLCodableTests: XCTestCase { buffer.moveReaderIndex(forwardBy: 1) XCTAssertThrowsError(try UUID.decode(from: &buffer, type: .uuid, format: .binary, context: .forTests())) { error in - XCTAssertEqual((error as? PSQLCastingError)?.line, #line - 1) - XCTAssertEqual((error as? PSQLCastingError)?.file, #file) - - XCTAssertEqual((error as? PSQLCastingError)?.columnIndex, 0) - XCTAssertEqual((error as? PSQLCastingError)?.postgresData, buffer) + XCTAssertEqual(error as? PostgresCastingError.Code, .failure) } } @@ -102,12 +98,8 @@ class UUID_PSQLCodableTests: XCTestCase { for dataType in dataTypes { var loopBuffer = buffer - XCTAssertThrowsError(try UUID.decode(from: &loopBuffer, type: dataType, format: .binary, context: .forTests())) { error in - XCTAssertEqual((error as? PSQLCastingError)?.line, #line - 1) - XCTAssertEqual((error as? PSQLCastingError)?.file, #file) - - XCTAssertEqual((error as? PSQLCastingError)?.columnIndex, 0) - XCTAssertEqual((error as? PSQLCastingError)?.postgresData, loopBuffer) + XCTAssertThrowsError(try UUID.decode(from: &loopBuffer, type: dataType, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } } @@ -120,14 +112,9 @@ class UUID_PSQLCodableTests: XCTestCase { let dataTypes: [PostgresDataType] = [.bool, .int8, .int2, .int4Array] for dataType in dataTypes { - let data = PSQLData(bytes: buffer, dataType: dataType, format: .binary) - - XCTAssertThrowsError(try data.decode(as: UUID.self, context: .forTests())) { error in - XCTAssertEqual((error as? PSQLCastingError)?.line, #line - 1) - XCTAssertEqual((error as? PSQLCastingError)?.file, #file) - - XCTAssertEqual((error as? PSQLCastingError)?.columnIndex, 0) - XCTAssertEqual((error as? PSQLCastingError)?.postgresData, data.bytes) + var copy = buffer + XCTAssertThrowsError(try UUID.decode(from: ©, type: dataType, format: .binary, context: .forTests())) { + XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } } } diff --git a/Tests/PostgresNIOTests/New/PSQLDataTests.swift b/Tests/PostgresNIOTests/New/PSQLDataTests.swift deleted file mode 100644 index c76b8d07..00000000 --- a/Tests/PostgresNIOTests/New/PSQLDataTests.swift +++ /dev/null @@ -1,18 +0,0 @@ -import NIOCore -import XCTest -@testable import PostgresNIO - -class PSQLDataTests: XCTestCase { - func testStringDecoding() { - let emptyBuffer: ByteBuffer? = nil - - let data = PSQLData(bytes: emptyBuffer, dataType: .text, format: .binary) - - var emptyResult: String? - XCTAssertNoThrow(emptyResult = try data.decodeIfPresent(as: String.self, context: .forTests())) - XCTAssertNil(emptyResult) - - XCTAssertNoThrow(emptyResult = try data.decode(as: String?.self, context: .forTests())) - XCTAssertNil(emptyResult) - } -} diff --git a/Tests/PostgresNIOTests/New/PostgresErrorTests.swift b/Tests/PostgresNIOTests/New/PostgresErrorTests.swift new file mode 100644 index 00000000..697933ea --- /dev/null +++ b/Tests/PostgresNIOTests/New/PostgresErrorTests.swift @@ -0,0 +1,28 @@ +@testable import PostgresNIO +import XCTest + +final class PostgresCastingErrorTests: XCTestCase { + func testPostgresCastingErrorEquality() { + let error1 = PostgresCastingError( + code: .typeMismatch, + columnName: "column", + columnIndex: 0, + targetType: String.self, + postgresType: .text, + postgresData: ByteBuffer(string: "hello world") + ) + + let error2 = PostgresCastingError( + code: .typeMismatch, + columnName: "column", + columnIndex: 0, + targetType: Int.self, + postgresType: .text, + postgresData: ByteBuffer(string: "hello world") + ) + + XCTAssertNotEqual(error1, error2) + let error3 = error1 + XCTAssertEqual(error1, error3) + } +} From 3ee12457bd9fe39db81d4877170372f87ad4b11a Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 17 Feb 2022 23:53:43 +0100 Subject: [PATCH 051/292] Add async option to PSQLRowStream (#206) --- Sources/PostgresNIO/New/PSQLRowSequence.swift | 606 ++++++++++++++++++ Sources/PostgresNIO/New/PSQLRowStream.swift | 103 +++ .../New/PSQLRowSequenceTests.swift | 466 ++++++++++++++ 3 files changed, 1175 insertions(+) create mode 100644 Sources/PostgresNIO/New/PSQLRowSequence.swift create mode 100644 Tests/PostgresNIOTests/New/PSQLRowSequenceTests.swift diff --git a/Sources/PostgresNIO/New/PSQLRowSequence.swift b/Sources/PostgresNIO/New/PSQLRowSequence.swift new file mode 100644 index 00000000..17ba1659 --- /dev/null +++ b/Sources/PostgresNIO/New/PSQLRowSequence.swift @@ -0,0 +1,606 @@ +import NIOCore +import NIOConcurrencyHelpers + +#if swift(>=5.5) && canImport(_Concurrency) +/// An async sequence of ``PSQLRow``s. +/// +/// - Note: This is a struct to allow us to move to a move only type easily once they become available. +struct PSQLRowSequence: AsyncSequence { + typealias Element = PSQLRow + typealias AsyncIterator = Iterator + + final class _Internal { + + let consumer: AsyncStreamConsumer + + init(consumer: AsyncStreamConsumer) { + self.consumer = consumer + } + + deinit { + // if no iterator was created, we need to cancel the stream + self.consumer.sequenceDeinitialized() + } + + func makeAsyncIterator() -> Iterator { + self.consumer.makeAsyncIterator() + } + } + + let _internal: _Internal + + init(_ consumer: AsyncStreamConsumer) { + self._internal = .init(consumer: consumer) + } + + func makeAsyncIterator() -> Iterator { + self._internal.makeAsyncIterator() + } +} + +extension PSQLRowSequence { + struct Iterator: AsyncIteratorProtocol { + typealias Element = PSQLRow + + let _internal: _Internal + + init(consumer: AsyncStreamConsumer) { + self._internal = _Internal(consumer: consumer) + } + + mutating func next() async throws -> PSQLRow? { + try await self._internal.next() + } + + final class _Internal { + let consumer: AsyncStreamConsumer + + init(consumer: AsyncStreamConsumer) { + self.consumer = consumer + } + + deinit { + self.consumer.iteratorDeinitialized() + } + + func next() async throws -> PSQLRow? { + try await self.consumer.next() + } + } + } +} + +final class AsyncStreamConsumer { + let lock = Lock() + + let lookupTable: [String: Int] + let columns: [RowDescription.Column] + private var state: StateMachine + + init( + lookupTable: [String: Int], + columns: [RowDescription.Column] + ) { + self.state = StateMachine() + + self.lookupTable = lookupTable + self.columns = columns + } + + func startCompleted(_ buffer: CircularBuffer, commandTag: String) { + self.lock.withLock { + self.state.finished(buffer, commandTag: commandTag) + } + } + + func startStreaming(_ buffer: CircularBuffer, upstream: PSQLRowStream) { + self.lock.withLock { + self.state.buffered(buffer, upstream: upstream) + } + } + + func startFailed(_ error: Error) { + self.lock.withLock { + self.state.failed(error) + } + } + + func receive(_ newRows: [DataRow]) { + let receiveAction = self.lock.withLock { + self.state.receive(newRows) + } + + switch receiveAction { + case .succeed(let continuation, let data, signalDemandTo: let source): + let row = PSQLRow( + data: data, + lookupTable: self.lookupTable, + columns: self.columns + ) + continuation.resume(returning: row) + source?.demand() + + case .none: + break + } + } + + func receive(completion result: Result) { + let completionAction = self.lock.withLock { + self.state.receive(completion: result) + } + + switch completionAction { + case .succeed(let continuation): + continuation.resume(returning: nil) + + case .fail(let continuation, let error): + continuation.resume(throwing: error) + + case .none: + break + } + } + + func sequenceDeinitialized() { + let action = self.lock.withLock { + self.state.sequenceDeinitialized() + } + + switch action { + case .cancelStream(let source): + source.cancel() + case .none: + break + } + } + + func makeAsyncIterator() -> PSQLRowSequence.Iterator { + self.lock.withLock { + self.state.createAsyncIterator() + } + let iterator = PSQLRowSequence.Iterator(consumer: self) + return iterator + } + + func iteratorDeinitialized() { + let action = self.lock.withLock { + self.state.iteratorDeinitialized() + } + + switch action { + case .cancelStream(let source): + source.cancel() + case .none: + break + } + } + + func next() async throws -> PSQLRow? { + self.lock.lock() + switch self.state.next() { + case .returnNil: + self.lock.unlock() + return nil + + case .returnRow(let data, signalDemandTo: let source): + self.lock.unlock() + source?.demand() + return PSQLRow( + data: data, + lookupTable: self.lookupTable, + columns: self.columns + ) + + case .throwError(let error): + self.lock.unlock() + throw error + + case .hitSlowPath: + return try await withCheckedThrowingContinuation { continuation in + let slowPathAction = self.state.next(for: continuation) + self.lock.unlock() + switch slowPathAction { + case .signalDemand(let source): + source.demand() + case .none: + break + } + } + } + } + +} + +extension AsyncStreamConsumer { + struct StateMachine { + enum UpstreamState { + enum DemandState { + case canAskForMore + case waitingForMore(CheckedContinuation?) + } + + case initialized + /// The upstream has more data that can be received + case streaming(AdaptiveRowBuffer, PSQLRowStream, DemandState) + /// The upstream has finished, but the downstream has not consumed all events. + case finished(AdaptiveRowBuffer, String) + /// The upstream has failed, but the downstream has not consumed the error yet. + case failed(Error) + /// The upstream has failed or finished and the downstream has consumed all events. Final state. + case consumed + + /// A state used to prevent CoW allocations when modifying an internal struct in the + /// `.streaming` or `.finished` state. + case modifying + } + + enum DownstreamState { + case sequenceCreated + case iteratorCreated + } + + var upstreamState = UpstreamState.initialized + var downstreamState = DownstreamState.sequenceCreated + + init() {} + + mutating func buffered(_ buffer: CircularBuffer, upstream: PSQLRowStream) { + switch self.upstreamState { + case .initialized: + let adaptive = AdaptiveRowBuffer(buffer) + self.upstreamState = .streaming(adaptive, upstream, buffer.isEmpty ? .waitingForMore(nil) : .canAskForMore) + + case .streaming, .finished, .failed, .consumed, .modifying: + preconditionFailure("Invalid upstream state: \(self.upstreamState)") + } + } + + mutating func finished(_ buffer: CircularBuffer, commandTag: String) { + switch self.upstreamState { + case .initialized: + let adaptive = AdaptiveRowBuffer(buffer) + self.upstreamState = .finished(adaptive, commandTag) + + case .streaming, .finished, .failed, .consumed, .modifying: + preconditionFailure("Invalid upstream state: \(self.upstreamState)") + } + } + + mutating func failed(_ error: Error) { + switch self.upstreamState { + case .initialized: + self.upstreamState = .failed(error) + + case .streaming, .finished, .failed, .consumed, .modifying: + preconditionFailure("Invalid upstream state: \(self.upstreamState)") + } + } + + mutating func createAsyncIterator() { + switch self.downstreamState { + case .sequenceCreated: + self.downstreamState = .iteratorCreated + case .iteratorCreated: + preconditionFailure("An iterator already exists") + } + } + + enum SequenceDeinitializedAction { + case cancelStream(PSQLRowStream) + case none + } + + mutating func sequenceDeinitialized() -> SequenceDeinitializedAction { + switch (self.downstreamState, self.upstreamState) { + case (.sequenceCreated, .initialized): + preconditionFailure("Invalid state: \(self.downstreamState), \(self.upstreamState)") + + case (.sequenceCreated, .streaming(_, let source, _)): + return .cancelStream(source) + + case (.sequenceCreated, .finished), + (.sequenceCreated, .consumed), + (.sequenceCreated, .failed): + return .none + + case (.iteratorCreated, _): + return .none + + case (_, .modifying): + preconditionFailure("Invalid state: \(self.downstreamState), \(self.upstreamState)") + } + } + + mutating func iteratorDeinitialized() -> SequenceDeinitializedAction { + switch (self.downstreamState, self.upstreamState) { + case (.sequenceCreated, _), + (.iteratorCreated, .initialized): + preconditionFailure("Invalid state: \(self.downstreamState), \(self.upstreamState)") + + case (.iteratorCreated, .streaming(_, let source, _)): + return .cancelStream(source) + + case (.iteratorCreated, .finished), + (.iteratorCreated, .consumed), + (.iteratorCreated, .failed): + return .none + + case (_, .modifying): + preconditionFailure("Invalid state: \(self.downstreamState), \(self.upstreamState)") + } + } + + + enum NextFastPathAction { + case hitSlowPath + case throwError(Error) + case returnRow(DataRow, signalDemandTo: PSQLRowStream?) + case returnNil + } + + mutating func next() -> NextFastPathAction { + switch self.upstreamState { + case .initialized: + preconditionFailure() + + case .streaming(var buffer, let source, .canAskForMore): + self.upstreamState = .modifying + guard let (data, demand) = buffer.popFirst() else { + self.upstreamState = .streaming(buffer, source, .canAskForMore) + return .hitSlowPath + } + if demand { + self.upstreamState = .streaming(buffer, source, .waitingForMore(.none)) + return .returnRow(data, signalDemandTo: source) + } + self.upstreamState = .streaming(buffer, source, .canAskForMore) + return .returnRow(data, signalDemandTo: nil) + + case .streaming(var buffer, let source, .waitingForMore(.none)): + self.upstreamState = .modifying + guard let (data, _) = buffer.popFirst() else { + self.upstreamState = .streaming(buffer, source, .waitingForMore(.none)) + return .hitSlowPath + } + + self.upstreamState = .streaming(buffer, source, .waitingForMore(.none)) + return .returnRow(data, signalDemandTo: nil) + + case .streaming(_, _, .waitingForMore(.some)): + preconditionFailure() + + case .finished(var buffer, let commandTag): + self.upstreamState = .modifying + guard let (data, _) = buffer.popFirst() else { + self.upstreamState = .consumed + return .returnNil + } + + self.upstreamState = .finished(buffer, commandTag) + return .returnRow(data, signalDemandTo: nil) + + case .failed(let error): + self.upstreamState = .consumed + return .throwError(error) + + case .consumed: + return .returnNil + + case .modifying: + preconditionFailure("Invalid upstream state: \(self.upstreamState)") + } + } + + enum NextSlowPathAction { + case signalDemand(PSQLRowStream) + case none + } + + mutating func next(for continuation: CheckedContinuation) -> NextSlowPathAction { + switch self.upstreamState { + case .initialized: + preconditionFailure() + + case .streaming(let buffer, let source, .canAskForMore): + precondition(buffer.isEmpty) + self.upstreamState = .streaming(buffer, source, .waitingForMore(continuation)) + return .signalDemand(source) + + case .streaming(let buffer, let source, .waitingForMore(.none)): + precondition(buffer.isEmpty) + self.upstreamState = .streaming(buffer, source, .waitingForMore(continuation)) + return .none + + case .streaming(_, _, .waitingForMore(.some)), + .finished, + .failed, + .consumed: + preconditionFailure("Expected that state was already handled by fast path. Invalid upstream state: \(self.upstreamState)") + + case .modifying: + preconditionFailure("Invalid upstream state: \(self.upstreamState)") + } + } + + enum ReceiveAction { + case succeed(CheckedContinuation, DataRow, signalDemandTo: PSQLRowStream?) + case none + } + + mutating func receive(_ newRows: [DataRow]) -> ReceiveAction { + precondition(!newRows.isEmpty) + + switch self.upstreamState { + case .streaming(var buffer, let source, .waitingForMore(.some(let continuation))): + buffer.append(contentsOf: newRows) + let (first, demand) = buffer.removeFirst() + if demand { + self.upstreamState = .streaming(buffer, source, .waitingForMore(.none)) + return .succeed(continuation, first, signalDemandTo: source) + } + self.upstreamState = .streaming(buffer, source, .canAskForMore) + return .succeed(continuation, first, signalDemandTo: nil) + + case .streaming(var buffer, let source, .waitingForMore(.none)): + buffer.append(contentsOf: newRows) + self.upstreamState = .streaming(buffer, source, .canAskForMore) + return .none + + case .streaming(var buffer, let source, .canAskForMore): + buffer.append(contentsOf: newRows) + self.upstreamState = .streaming(buffer, source, .canAskForMore) + return .none + + case .initialized, .finished, .consumed: + preconditionFailure() + + case .failed: + return .none + + case .modifying: + preconditionFailure() + } + } + + enum CompletionResult { + case succeed(CheckedContinuation) + case fail(CheckedContinuation, Error) + case none + } + + mutating func receive(completion result: Result) -> CompletionResult { + switch result { + case .success(let commandTag): + return self.receiveEnd(commandTag: commandTag) + case .failure(let error): + return self.receiveError(error) + } + } + + private mutating func receiveEnd(commandTag: String) -> CompletionResult { + switch self.upstreamState { + case .streaming(let buffer, _, .waitingForMore(.some(let continuation))): + precondition(buffer.isEmpty) + self.upstreamState = .consumed + return .succeed(continuation) + + case .streaming(let buffer, _, .waitingForMore(.none)): + self.upstreamState = .finished(buffer, commandTag) + return .none + + case .streaming(let buffer, _, .canAskForMore): + self.upstreamState = .finished(buffer, commandTag) + return .none + + case .initialized, .finished, .consumed: + preconditionFailure("Invalid upstream state: \(self.upstreamState)") + + case .failed: + return .none + + case .modifying: + preconditionFailure() + } + } + + private mutating func receiveError(_ error: Error) -> CompletionResult { + switch self.upstreamState { + case .streaming(let buffer, _, .waitingForMore(.some(let continuation))): + precondition(buffer.isEmpty) + self.upstreamState = .consumed + return .fail(continuation, error) + + case .streaming(_, _, .waitingForMore(.none)): + self.upstreamState = .failed(error) + return .none + + case .streaming(_, _, .canAskForMore): + self.upstreamState = .failed(error) + return .none + + case .initialized, .finished, .consumed: + preconditionFailure("Invalid upstream state: \(self.upstreamState)") + + case .failed: + return .none + + case .modifying: + preconditionFailure() + } + } + } +} + +extension PSQLRowSequence { + func collect() async throws -> [PSQLRow] { + var result = [PSQLRow]() + for try await row in self { + result.append(row) + } + return result + } +} + +struct AdaptiveRowBuffer { + static let defaultBufferTarget = 256 + static let defaultBufferMinimum = 1 + static let defaultBufferMaximum = 16384 + + let minimum: Int + let maximum: Int + + private var circularBuffer: CircularBuffer + private var target: Int + private var canShrink: Bool = false + + var isEmpty: Bool { + self.circularBuffer.isEmpty + } + + init(minimum: Int, maximum: Int, target: Int, buffer: CircularBuffer) { + precondition(minimum <= target && target <= maximum) + self.minimum = minimum + self.maximum = maximum + self.target = target + self.circularBuffer = buffer + } + + init(_ circularBuffer: CircularBuffer) { + self.init( + minimum: Self.defaultBufferMinimum, + maximum: Self.defaultBufferMaximum, + target: Self.defaultBufferTarget, + buffer: circularBuffer + ) + } + + mutating func append(contentsOf newRows: Rows) where Rows.Element == DataRow { + self.circularBuffer.append(contentsOf: newRows) + if self.circularBuffer.count >= self.target, self.canShrink, self.target > self.minimum { + self.target &>>= 1 + } + self.canShrink = true + } + + /// Returns the next row in the FIFO buffer and a `bool` signalling if new rows should be loaded. + mutating func removeFirst() -> (DataRow, Bool) { + let element = self.circularBuffer.removeFirst() + + // If the buffer is drained now, we should double our target size. + if self.circularBuffer.count == 0, self.target < self.maximum { + self.target = self.target * 2 + self.canShrink = false + } + + return (element, self.circularBuffer.count < self.target) + } + + mutating func popFirst() -> (DataRow, Bool)? { + guard !self.circularBuffer.isEmpty else { + return nil + } + return self.removeFirst() + } +} +#endif diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index 3262e995..d6aea9a1 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -21,6 +21,10 @@ final class PSQLRowStream { case iteratingRows(onRow: (PSQLRow) throws -> (), EventLoopPromise, PSQLRowsDataSource) case waitingForAll([PSQLRow], EventLoopPromise<[PSQLRow]>, PSQLRowsDataSource) case consumed(Result) + + #if swift(>=5.5) && canImport(_Concurrency) + case asyncSequence(AsyncStreamConsumer, PSQLRowsDataSource) + #endif } internal let rowDescription: [RowDescription.Column] @@ -56,7 +60,89 @@ final class PSQLRowStream { } self.lookupTable = lookup } + + // MARK: Async Sequence + + #if swift(>=5.5) && canImport(_Concurrency) + func asyncSequence() -> PSQLRowSequence { + self.eventLoop.preconditionInEventLoop() + guard case .waitingForConsumer(let bufferState) = self.downstreamState else { + preconditionFailure("Invalid state: \(self.downstreamState)") + } + + let consumer = AsyncStreamConsumer( + lookupTable: self.lookupTable, + columns: self.rowDescription + ) + + switch bufferState { + case .streaming(let bufferedRows, let dataSource): + consumer.startStreaming(bufferedRows, upstream: self) + self.downstreamState = .asyncSequence(consumer, dataSource) + + case .finished(let buffer, let commandTag): + consumer.startCompleted(buffer, commandTag: commandTag) + self.downstreamState = .consumed(.success(commandTag)) + + case .failure(let error): + consumer.startFailed(error) + self.downstreamState = .consumed(.failure(error)) + } + + return PSQLRowSequence(consumer) + } + + func demand() { + if self.eventLoop.inEventLoop { + self.demand0() + } else { + self.eventLoop.execute { + self.demand0() + } + } + } + + private func demand0() { + switch self.downstreamState { + case .waitingForConsumer, .iteratingRows, .waitingForAll: + preconditionFailure("Invalid state: \(self.downstreamState)") + + case .consumed: + break + + case .asyncSequence(_, let dataSource): + dataSource.request(for: self) + } + } + + func cancel() { + if self.eventLoop.inEventLoop { + self.cancel0() + } else { + self.eventLoop.execute { + self.cancel0() + } + } + } + + private func cancel0() { + switch self.downstreamState { + case .asyncSequence(let consumer, let dataSource): + let error = PSQLError.connectionClosed + self.downstreamState = .consumed(.failure(error)) + consumer.receive(completion: .failure(error)) + dataSource.cancel(for: self) + + case .consumed: + return + + case .waitingForConsumer, .iteratingRows, .waitingForAll: + preconditionFailure("Invalid state: \(self.downstreamState)") + } + } + #endif + // MARK: Consume in array func all() -> EventLoopFuture<[PSQLRow]> { @@ -217,6 +303,11 @@ final class PSQLRowStream { self.downstreamState = .waitingForAll(rows, promise, dataSource) // immediately request more dataSource.request(for: self) + + #if swift(>=5.5) && canImport(_Concurrency) + case .asyncSequence(let consumer, _): + consumer.receive(newRows) + #endif case .consumed(.success): preconditionFailure("How can we receive further rows, if we are supposed to be done") @@ -253,6 +344,12 @@ final class PSQLRowStream { self.downstreamState = .consumed(.success(commandTag)) promise.succeed(rows) + #if swift(>=5.5) && canImport(_Concurrency) + case .asyncSequence(let consumer, _): + consumer.receive(completion: .success(commandTag)) + self.downstreamState = .consumed(.success(commandTag)) + #endif + case .consumed: break } @@ -274,6 +371,12 @@ final class PSQLRowStream { self.downstreamState = .consumed(.failure(error)) promise.fail(error) + #if swift(>=5.5) && canImport(_Concurrency) + case .asyncSequence(let consumer, _): + consumer.receive(completion: .failure(error)) + self.downstreamState = .consumed(.failure(error)) + #endif + case .consumed: break } diff --git a/Tests/PostgresNIOTests/New/PSQLRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PSQLRowSequenceTests.swift new file mode 100644 index 00000000..d3dd9665 --- /dev/null +++ b/Tests/PostgresNIOTests/New/PSQLRowSequenceTests.swift @@ -0,0 +1,466 @@ +import NIOEmbedded +import NIOConcurrencyHelpers +import Dispatch +import XCTest +@testable import PostgresNIO + +#if swift(>=5.5.2) +final class PSQLRowSequenceTests: XCTestCase { + + func testBackpressureWorks() async throws { + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + let logger = Logger(label: "test") + let dataSource = MockRowDataSource() + let stream = PSQLRowStream( + rowDescription: [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + eventLoop: eventLoop, + rowSource: .stream(dataSource) + ) + promise.succeed(stream) + + let rowSequence = stream.asyncSequence() + XCTAssertEqual(dataSource.requestCount, 0) + let dataRow: DataRow = [ByteBuffer(integer: Int64(1))] + stream.receive([dataRow]) + + var iterator = rowSequence.makeAsyncIterator() + let row = try await iterator.next() + XCTAssertEqual(dataSource.requestCount, 1) + XCTAssertEqual(row?.data, dataRow) + + stream.receive(completion: .success("SELECT 1")) + let empty = try await iterator.next() + XCTAssertNil(empty) + } + + func testCancellationWorksWhileIterating() async throws { + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + let logger = Logger(label: "test") + let dataSource = MockRowDataSource() + let stream = PSQLRowStream( + rowDescription: [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + eventLoop: eventLoop, + rowSource: .stream(dataSource) + ) + promise.succeed(stream) + + let rowSequence = stream.asyncSequence() + XCTAssertEqual(dataSource.requestCount, 0) + let dataRows: [DataRow] = (0..<128).map { [ByteBuffer(integer: Int64($0))] } + stream.receive(dataRows) + + var counter = 0 + for try await row in rowSequence { + XCTAssertEqual(try row.decode(column: 0, as: Int.self), counter) + counter += 1 + + if counter == 64 { + break + } + } + + XCTAssertEqual(dataSource.cancelCount, 1) + } + + func testCancellationWorksBeforeIterating() async throws { + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + let logger = Logger(label: "test") + let dataSource = MockRowDataSource() + let stream = PSQLRowStream( + rowDescription: [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + eventLoop: eventLoop, + rowSource: .stream(dataSource) + ) + promise.succeed(stream) + + let rowSequence = stream.asyncSequence() + XCTAssertEqual(dataSource.requestCount, 0) + let dataRows: [DataRow] = (0..<128).map { [ByteBuffer(integer: Int64($0))] } + stream.receive(dataRows) + + var iterator: PSQLRowSequence.Iterator? = rowSequence.makeAsyncIterator() + iterator = nil + + XCTAssertEqual(dataSource.cancelCount, 1) + XCTAssertNil(iterator, "Surpress warning") + } + + func testDroppingTheSequenceCancelsTheSource() async throws { + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + let logger = Logger(label: "test") + let dataSource = MockRowDataSource() + let stream = PSQLRowStream( + rowDescription: [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + eventLoop: eventLoop, + rowSource: .stream(dataSource) + ) + promise.succeed(stream) + + var rowSequence: PSQLRowSequence? = stream.asyncSequence() + rowSequence = nil + + XCTAssertEqual(dataSource.cancelCount, 1) + XCTAssertNil(rowSequence, "Surpress warning") + } + + func testStreamBasedOnCompletedQuery() async throws { + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + let logger = Logger(label: "test") + let dataSource = MockRowDataSource() + let stream = PSQLRowStream( + rowDescription: [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + eventLoop: eventLoop, + rowSource: .stream(dataSource) + ) + promise.succeed(stream) + + let rowSequence = stream.asyncSequence() + let dataRows: [DataRow] = (0..<128).map { [ByteBuffer(integer: Int64($0))] } + stream.receive(dataRows) + stream.receive(completion: .success("SELECT 128")) + + var counter = 0 + for try await row in rowSequence { + XCTAssertEqual(try row.decode(column: 0, as: Int.self), counter) + counter += 1 + } + + XCTAssertEqual(dataSource.cancelCount, 0) + } + + func testStreamIfInitializedWithAllData() async throws { + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + let logger = Logger(label: "test") + let dataSource = MockRowDataSource() + let stream = PSQLRowStream( + rowDescription: [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + eventLoop: eventLoop, + rowSource: .stream(dataSource) + ) + promise.succeed(stream) + + let dataRows: [DataRow] = (0..<128).map { [ByteBuffer(integer: Int64($0))] } + stream.receive(dataRows) + stream.receive(completion: .success("SELECT 128")) + + let rowSequence = stream.asyncSequence() + + var counter = 0 + for try await row in rowSequence { + XCTAssertEqual(try row.decode(column: 0, as: Int.self), counter) + counter += 1 + } + + XCTAssertEqual(dataSource.cancelCount, 0) + } + + func testStreamIfInitializedWithError() async throws { + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + let logger = Logger(label: "test") + let dataSource = MockRowDataSource() + let stream = PSQLRowStream( + rowDescription: [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + eventLoop: eventLoop, + rowSource: .stream(dataSource) + ) + promise.succeed(stream) + + stream.receive(completion: .failure(PSQLError.connectionClosed)) + + let rowSequence = stream.asyncSequence() + + do { + var counter = 0 + for try await _ in rowSequence { + counter += 1 + } + XCTFail("Expected that an error was thrown before.") + } catch { + XCTAssertEqual(error as? PSQLError, .connectionClosed) + } + } + + func testSucceedingRowContinuationsWorks() async throws { + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + let logger = Logger(label: "test") + let dataSource = MockRowDataSource() + let stream = PSQLRowStream( + rowDescription: [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + eventLoop: eventLoop, + rowSource: .stream(dataSource) + ) + promise.succeed(stream) + + let rowSequence = stream.asyncSequence() + var rowIterator = rowSequence.makeAsyncIterator() + + DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + let dataRows: [DataRow] = (0..<1).map { [ByteBuffer(integer: Int64($0))] } + stream.receive(dataRows) + } + + let row1 = try await rowIterator.next() + XCTAssertEqual(try row1?.decode(column: 0, as: Int.self), 0) + + DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + stream.receive(completion: .success("SELECT 1")) + } + + let row2 = try await rowIterator.next() + XCTAssertNil(row2) + } + + func testFailingRowContinuationsWorks() async throws { + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + let logger = Logger(label: "test") + let dataSource = MockRowDataSource() + let stream = PSQLRowStream( + rowDescription: [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + eventLoop: eventLoop, + rowSource: .stream(dataSource) + ) + promise.succeed(stream) + + let rowSequence = stream.asyncSequence() + var rowIterator = rowSequence.makeAsyncIterator() + + DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + let dataRows: [DataRow] = (0..<1).map { [ByteBuffer(integer: Int64($0))] } + stream.receive(dataRows) + } + + let row1 = try await rowIterator.next() + XCTAssertEqual(try row1?.decode(column: 0, as: Int.self), 0) + + DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + stream.receive(completion: .failure(PSQLError.connectionClosed)) + } + + do { + _ = try await rowIterator.next() + XCTFail("Expected that an error was thrown before.") + } catch { + XCTAssertEqual(error as? PSQLError, .connectionClosed) + } + } + + func testAdaptiveRowBufferShrinksAndGrows() async throws { + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + let logger = Logger(label: "test") + let dataSource = MockRowDataSource() + let stream = PSQLRowStream( + rowDescription: [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + eventLoop: eventLoop, + rowSource: .stream(dataSource) + ) + promise.succeed(stream) + + let initialDataRows: [DataRow] = (0.. don't ask for more + XCTAssertEqual(dataSource.requestCount, 0) + _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more + XCTAssertEqual(dataSource.requestCount, 1) + + // if the buffer gets new rows so that it has equal or more than target (the target size + // should be halved), however shrinking is only allowed AFTER the first extra rows were + // received. + let addDataRows1: [DataRow] = [[ByteBuffer(integer: Int64(0))]] + stream.receive(addDataRows1) + _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more + XCTAssertEqual(dataSource.requestCount, 2) + + // if the buffer gets new rows so that it has equal or more than target (the target size + // should be halved) + let addDataRows2: [DataRow] = [[ByteBuffer(integer: Int64(0))]] + stream.receive(addDataRows2) // this should to target being halved. + _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more + for _ in 0..<(AdaptiveRowBuffer.defaultBufferTarget/2 - 1) { + _ = try await rowIterator.next() // Remove all rows until we are back at target + XCTAssertEqual(dataSource.requestCount, 2) + } + + // if we remove another row we should trigger getting new rows. + _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more + XCTAssertEqual(dataSource.requestCount, 3) + + // remove all remaining rows... this will trigger a target size double + for _ in 0..<(AdaptiveRowBuffer.defaultBufferTarget/2 - 1) { + _ = try await rowIterator.next() // Remove all rows until we are back at target + XCTAssertEqual(dataSource.requestCount, 3) + } + + let fillBufferDataRows: [DataRow] = (0.. don't ask for more + XCTAssertEqual(dataSource.requestCount, 3) + _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more + XCTAssertEqual(dataSource.requestCount, 4) + } + + func testAdaptiveRowShrinksToMin() async throws { + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + let logger = Logger(label: "test") + let dataSource = MockRowDataSource() + let stream = PSQLRowStream( + rowDescription: [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + eventLoop: eventLoop, + rowSource: .stream(dataSource) + ) + promise.succeed(stream) + + var currentTarget = AdaptiveRowBuffer.defaultBufferTarget + + let initialDataRows: [DataRow] = (0.. AdaptiveRowBuffer.defaultBufferMinimum { + // the buffer is filled up to currentTarget at that point, if we remove one row and add + // one row it should shrink + XCTAssertEqual(dataSource.requestCount, expectedRequestCount) + _ = try await rowIterator.next() + expectedRequestCount += 1 + XCTAssertEqual(dataSource.requestCount, expectedRequestCount) + + stream.receive([[ByteBuffer(integer: Int64(1))]]) + let newTarget = currentTarget / 2 + + // consume all messages that are to much. + for _ in 0.. Date: Fri, 18 Feb 2022 00:05:31 +0100 Subject: [PATCH 052/292] Make PSQLDecodingContext generic (#217) --- .../New/Data/Array+PSQLCodable.swift | 7 +++- .../New/Data/Bool+PSQLCodable.swift | 7 +++- .../New/Data/Bytes+PSQLCodable.swift | 14 ++++++-- .../New/Data/Date+PSQLCodable.swift | 7 +++- .../New/Data/Decimal+PSQLCodable.swift | 7 +++- .../New/Data/Float+PSQLCodable.swift | 14 ++++++-- .../New/Data/Int+PSQLCodable.swift | 35 +++++++++++++++--- .../New/Data/JSON+PSQLCodable.swift | 7 +++- .../New/Data/Optional+PSQLCodable.swift | 14 ++++---- .../Data/RawRepresentable+PSQLCodable.swift | 7 +++- .../New/Data/String+PSQLCodable.swift | 7 +++- .../New/Data/UUID+PSQLCodable.swift | 7 +++- Sources/PostgresNIO/New/PSQLCodable.swift | 36 +++++++++---------- Sources/PostgresNIO/New/PSQLRow.swift | 7 +--- Sources/PostgresNIO/Postgres+PSQLCompat.swift | 7 +++- .../New/Extensions/PSQLCoding+TestUtils.swift | 6 ++-- 16 files changed, 135 insertions(+), 54 deletions(-) diff --git a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift index ba89bbb8..d74717f7 100644 --- a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift @@ -101,7 +101,12 @@ extension Array: PSQLEncodable where Element: PSQLArrayElement { } extension Array: PSQLDecodable where Element: PSQLArrayElement { - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Array { + static func decode( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Array { guard case .binary = format else { // currently we only support decoding arrays in binary format. throw PostgresCastingError.Code.failure diff --git a/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift index 3d7a6776..d9896efe 100644 --- a/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift @@ -9,7 +9,12 @@ extension Bool: PSQLCodable { .binary } - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Bool { + static func decode( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self { guard type == .bool else { throw PostgresCastingError.Code.typeMismatch } diff --git a/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift index b359f3ca..8c5e96a5 100644 --- a/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift @@ -30,7 +30,12 @@ extension ByteBuffer: PSQLCodable { byteBuffer.writeBuffer(©OfSelf) } - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self { return buffer } } @@ -48,7 +53,12 @@ extension Data: PSQLCodable { byteBuffer.writeBytes(self) } - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self { return buffer.readData(length: buffer.readableBytes, byteTransferStrategy: .automatic)! } } diff --git a/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift index 71201853..05491a61 100644 --- a/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift @@ -10,7 +10,12 @@ extension Date: PSQLCodable { .binary } - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self { switch type { case .timestamp, .timestamptz: guard buffer.readableBytes == 8, let microseconds = buffer.readInteger(as: Int64.self) else { diff --git a/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift index 0a683e37..22c4785d 100644 --- a/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift @@ -10,7 +10,12 @@ extension Decimal: PSQLCodable { .binary } - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self { switch (format, type) { case (.binary, .numeric): guard let numeric = PostgresNumeric(buffer: &buffer) else { diff --git a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift index 0aab376f..a3463c8e 100644 --- a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift @@ -9,7 +9,12 @@ extension Float: PSQLCodable { .binary } - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self { switch (format, type) { case (.binary, .float4): guard buffer.readableBytes == 4, let float = buffer.psqlReadFloat() else { @@ -45,7 +50,12 @@ extension Double: PSQLCodable { .binary } - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self { switch (format, type) { case (.binary, .float4): guard buffer.readableBytes == 4, let float = buffer.psqlReadFloat() else { diff --git a/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift index d63bb8eb..49284a8a 100644 --- a/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift @@ -9,7 +9,12 @@ extension UInt8: PSQLCodable { .binary } - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self { switch type { case .bpchar, .char: guard buffer.readableBytes == 1, let value = buffer.readInteger(as: UInt8.self) else { @@ -37,7 +42,12 @@ extension Int16: PSQLCodable { .binary } - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { @@ -68,7 +78,12 @@ extension Int32: PSQLCodable { .binary } - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { @@ -104,7 +119,12 @@ extension Int64: PSQLCodable { .binary } - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { @@ -152,7 +172,12 @@ extension Int: PSQLCodable { .binary } - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { diff --git a/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift index 1500ce84..47aff5bf 100644 --- a/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift @@ -14,7 +14,12 @@ extension PSQLCodable where Self: Codable { .binary } - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self { switch (format, type) { case (.binary, .jsonb): guard JSONBVersionByte == buffer.readInteger(as: UInt8.self) else { diff --git a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift index a01d5f15..fef7d9d2 100644 --- a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift @@ -3,27 +3,25 @@ import NIOCore extension Optional: PSQLDecodable where Wrapped: PSQLDecodable, Wrapped.DecodableType == Wrapped { typealias DecodableType = Wrapped - static func decode( + static func decode( from byteBuffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, - context: PSQLDecodingContext + context: PostgresDecodingContext ) throws -> Optional { preconditionFailure("This should not be called") } - static func decodeRaw( + static func decodeRaw( from byteBuffer: inout ByteBuffer?, type: PostgresDataType, format: PostgresFormat, - context: PSQLDecodingContext + context: PostgresDecodingContext ) throws -> Self { - switch byteBuffer { - case .some(var buffer): - return try DecodableType.decode(from: &buffer, type: type, format: format, context: context) - case .none: + guard var buffer = byteBuffer else { return nil } + return try DecodableType.decode(from: &buffer, type: type, format: format, context: context) } } diff --git a/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift index f8812da3..2036fef6 100644 --- a/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift @@ -9,7 +9,12 @@ extension PSQLCodable where Self: RawRepresentable, RawValue: PSQLCodable { self.rawValue.psqlFormat } - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self { guard let rawValue = try? RawValue.decode(from: &buffer, type: type, format: format, context: context), let selfValue = Self.init(rawValue: rawValue) else { throw PostgresCastingError.Code.failure diff --git a/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift index d761fc48..66f4a400 100644 --- a/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift @@ -14,7 +14,12 @@ extension String: PSQLCodable { byteBuffer.writeString(self) } - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self { switch (format, type) { case (_, .varchar), (_, .text), diff --git a/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift index 0fdd2990..b258473b 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift @@ -22,7 +22,12 @@ extension UUID: PSQLCodable { ]) } - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self { switch (format, type) { case (.binary, .uuid): guard let uuid = buffer.readUUID() else { diff --git a/Sources/PostgresNIO/New/PSQLCodable.swift b/Sources/PostgresNIO/New/PSQLCodable.swift index fbf3fbbb..4b5dd041 100644 --- a/Sources/PostgresNIO/New/PSQLCodable.swift +++ b/Sources/PostgresNIO/New/PSQLCodable.swift @@ -33,21 +33,31 @@ protocol PSQLDecodable { /// - context: A `PSQLDecodingContext` providing context for decoding. This includes a `JSONDecoder` /// to use when decoding json and metadata to create better errors. /// - Returns: A decoded object - static func decode(from byteBuffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self + static func decode( + from byteBuffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self /// Decode an entity from the `byteBuffer` in postgres wire format. /// This method has a default implementation and may be overriden /// only for special cases, like `Optional`s. - static func decodeRaw(from byteBuffer: inout ByteBuffer?, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self + static func decodeRaw( + from byteBuffer: inout ByteBuffer?, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self } extension PSQLDecodable { @inlinable - public static func decodeRaw( + static func decodeRaw( from byteBuffer: inout ByteBuffer?, type: PostgresDataType, format: PostgresFormat, - context: PSQLDecodingContext + context: PostgresDecodingContext ) throws -> Self { guard var buffer = byteBuffer else { throw PostgresCastingError.Code.missingData @@ -79,22 +89,10 @@ struct PSQLEncodingContext { let jsonEncoder: PostgresJSONEncoder } -struct PSQLDecodingContext { +struct PostgresDecodingContext { + let jsonDecoder: JSONDecoder - let jsonDecoder: PostgresJSONDecoder - - let columnIndex: Int - let columnName: String - - let file: String - let line: Int - - init(jsonDecoder: PostgresJSONDecoder, columnName: String, columnIndex: Int, file: String, line: Int) { + init(jsonDecoder: JSONDecoder) { self.jsonDecoder = jsonDecoder - self.columnName = columnName - self.columnIndex = columnIndex - - self.file = file - self.line = line } } diff --git a/Sources/PostgresNIO/New/PSQLRow.swift b/Sources/PostgresNIO/New/PSQLRow.swift index dbd57c48..c86f62a1 100644 --- a/Sources/PostgresNIO/New/PSQLRow.swift +++ b/Sources/PostgresNIO/New/PSQLRow.swift @@ -48,12 +48,7 @@ extension PSQLRow { precondition(index < self.data.columnCount) let column = self.columns[index] - let context = PSQLDecodingContext( - jsonDecoder: jsonDecoder, - columnName: column.name, - columnIndex: index, - file: file, - line: line) + let context = PostgresDecodingContext(jsonDecoder: jsonDecoder) // Safe to force unwrap here, as we have ensured above that the row has enough columns var cellSlice = self.data[column: index]! diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index c0f7cef8..84bf7e56 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -26,7 +26,12 @@ extension PostgresData: PSQLEncodable { } extension PostgresData: PSQLDecodable { - static func decode(from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PSQLDecodingContext) throws -> Self { + static func decode( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Self { let myBuffer = buffer.readSlice(length: buffer.readableBytes)! return PostgresData(type: PostgresDataType(UInt32(type.rawValue)), typeModifier: nil, formatCode: .binary, value: myBuffer) diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift index 602306d8..3c83ac1f 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift @@ -7,9 +7,9 @@ extension PSQLFrontendMessageEncoder { } } -extension PSQLDecodingContext { - static func forTests(columnName: String = "unknown", columnIndex: Int = 0, jsonDecoder: PostgresJSONDecoder = JSONDecoder(), file: String = #file, line: Int = #line) -> Self { - Self(jsonDecoder: JSONDecoder(), columnName: columnName, columnIndex: columnIndex, file: file, line: line) +extension PostgresDecodingContext where JSONDecoder == Foundation.JSONDecoder { + static func forTests() -> Self { + Self(jsonDecoder: JSONDecoder()) } } From 112a5e5dfc729cef0ce18cfd282d83f60eda435c Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Thu, 17 Feb 2022 18:26:04 -0600 Subject: [PATCH 053/292] Cut back on number of CI jobs (#218) --- .github/workflows/test.yml | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ccf4b474..79021623 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,8 +9,8 @@ jobs: swiftver: - swift:5.2 - swift:5.3 - - swift:5.4 - swift:5.5 + - swiftlang/swift:nightly-5.6 - swiftlang/swift:nightly-main swiftos: - focal @@ -39,19 +39,15 @@ jobs: dbimage: - postgres:14 - postgres:13 - - postgres:12 - postgres:11 - dbauth: - - trust - - md5 - - scram-sha-256 - swiftver: - # Only test latest Swift for integration tests, issues from older Swift versions that don't show - # up in the unit tests are fairly unlikely. - - swift:5.5 - swiftos: - - focal - container: ${{ format('{0}-{1}', matrix.swiftver, matrix.swiftos) }} + include: + - dbimage: postgres:14 + dbauth: scram-sha-256 + - dbimage: postgres:13 + dbauth: md5 + - dbimage: postgres:11 + dbauth: trust + container: swift:5.5-focal runs-on: ubuntu-latest env: LOG_LEVEL: debug @@ -119,7 +115,7 @@ jobs: - scram-sha-256 xcode: - latest-stable - - latest + #- latest runs-on: macos-11 env: LOG_LEVEL: debug From 493728618c5aa7e18a1855eff4fcb4110f1ddee9 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 18 Feb 2022 11:03:39 +0100 Subject: [PATCH 054/292] Extend PostgresCastingError (#221) --- Sources/PostgresNIO/New/PSQLError.swift | 39 +++++++++++-------- .../New/PostgresErrorTests.swift | 26 ++++++++++++- 2 files changed, 46 insertions(+), 19 deletions(-) diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index 42dd221e..cdcf86c2 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -80,6 +80,7 @@ struct PSQLError: Error { } } +/// An error that may happen when a ``PostgresRow`` or ``PostgresCell`` is decoded to native Swift types. struct PostgresCastingError: Error, Equatable { struct Code: Hashable, Error { enum Base { @@ -99,31 +100,32 @@ struct PostgresCastingError: Error, Equatable { static let failure = Self.init(.failure) } + /// The casting error code let code: Code - + + /// The cell's column name for which the casting failed let columnName: String + /// The cell's column index for which the casting failed let columnIndex: Int + /// The swift type the cell should have been casted into let targetType: Any.Type + /// The cell's postgres data type for which the casting failed let postgresType: PostgresDataType + /// The cell's postgres format for which the casting failed + let postgresFormat: PostgresFormat + /// A copy of the cell data which was attempted to be casted let postgresData: ByteBuffer? + + /// The file the casting/decoding was attempted in + let file: String + /// The line the casting/decoding was attempted in + let line: Int var description: String { - switch self.code.base { - case .missingData: - return """ - Failed to cast Postgres data type \(self.postgresType.description) to Swift type \(self.targetType) \ - because of missing data. - """ - - case .typeMismatch: - preconditionFailure() - - case .failure: - return """ - Failed to cast Postgres data type \(self.postgresType.description) to Swift type \(self.targetType). - """ - } - + // This may seem very odd... But we are afraid that users might accidentally send the + // unfiltered errors out to end-users. This may leak security relevant information. For this + // reason we overwrite the error description by default to this generic "Database error" + "Database error" } static func ==(lhs: PostgresCastingError, rhs: PostgresCastingError) -> Bool { @@ -132,7 +134,10 @@ struct PostgresCastingError: Error, Equatable { && lhs.columnIndex == rhs.columnIndex && lhs.targetType == rhs.targetType && lhs.postgresType == rhs.postgresType + && lhs.postgresFormat == rhs.postgresFormat && lhs.postgresData == rhs.postgresData + && lhs.file == rhs.file + && lhs.line == rhs.line } } diff --git a/Tests/PostgresNIOTests/New/PostgresErrorTests.swift b/Tests/PostgresNIOTests/New/PostgresErrorTests.swift index 697933ea..79f673c1 100644 --- a/Tests/PostgresNIOTests/New/PostgresErrorTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresErrorTests.swift @@ -9,7 +9,10 @@ final class PostgresCastingErrorTests: XCTestCase { columnIndex: 0, targetType: String.self, postgresType: .text, - postgresData: ByteBuffer(string: "hello world") + postgresFormat: .binary, + postgresData: ByteBuffer(string: "hello world"), + file: "foo.swift", + line: 123 ) let error2 = PostgresCastingError( @@ -18,11 +21,30 @@ final class PostgresCastingErrorTests: XCTestCase { columnIndex: 0, targetType: Int.self, postgresType: .text, - postgresData: ByteBuffer(string: "hello world") + postgresFormat: .binary, + postgresData: ByteBuffer(string: "hello world"), + file: "foo.swift", + line: 123 ) XCTAssertNotEqual(error1, error2) let error3 = error1 XCTAssertEqual(error1, error3) } + + func testPostgresCastingErrorDescription() { + let error = PostgresCastingError( + code: .typeMismatch, + columnName: "column", + columnIndex: 0, + targetType: String.self, + postgresType: .text, + postgresFormat: .binary, + postgresData: ByteBuffer(string: "hello world"), + file: "foo.swift", + line: 123 + ) + + XCTAssertNotEqual("\(error)", "Database error") + } } From 05eaa2ee1c601c8e7d812789e3f54eca6a2a2070 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 18 Feb 2022 11:15:46 +0100 Subject: [PATCH 055/292] Add PostgresCell (#220) --- Sources/PostgresNIO/New/PSQLData.swift | 15 ----- Sources/PostgresNIO/New/PostgresCell.swift | 51 +++++++++++++++++ .../New/PostgresCellTests.swift | 57 +++++++++++++++++++ 3 files changed, 108 insertions(+), 15 deletions(-) delete mode 100644 Sources/PostgresNIO/New/PSQLData.swift create mode 100644 Sources/PostgresNIO/New/PostgresCell.swift create mode 100644 Tests/PostgresNIOTests/New/PostgresCellTests.swift diff --git a/Sources/PostgresNIO/New/PSQLData.swift b/Sources/PostgresNIO/New/PSQLData.swift deleted file mode 100644 index d490c78c..00000000 --- a/Sources/PostgresNIO/New/PSQLData.swift +++ /dev/null @@ -1,15 +0,0 @@ -import NIOCore - -struct PSQLData: Equatable { - - @usableFromInline var bytes: ByteBuffer? - @usableFromInline var dataType: PostgresDataType - @usableFromInline var format: PostgresFormat - - /// use this only for testing - init(bytes: ByteBuffer?, dataType: PostgresDataType, format: PostgresFormat) { - self.bytes = bytes - self.dataType = dataType - self.format = format - } -} diff --git a/Sources/PostgresNIO/New/PostgresCell.swift b/Sources/PostgresNIO/New/PostgresCell.swift new file mode 100644 index 00000000..a461bb37 --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresCell.swift @@ -0,0 +1,51 @@ +import NIOCore + +struct PostgresCell: Equatable { + var bytes: ByteBuffer? + var dataType: PostgresDataType + var format: PostgresFormat + + var columnName: String + var columnIndex: Int + + init(bytes: ByteBuffer?, dataType: PostgresDataType, format: PostgresFormat, columnName: String, columnIndex: Int) { + self.bytes = bytes + self.dataType = dataType + self.format = format + + self.columnName = columnName + self.columnIndex = columnIndex + } +} + +extension PostgresCell { + + func decode( + _: T.Type, + context: PostgresDecodingContext, + file: String = #file, + line: Int = #line + ) throws -> T { + var copy = self.bytes + do { + return try T.decodeRaw( + from: ©, + type: self.dataType, + format: self.format, + context: context + ) + } catch let code as PostgresCastingError.Code { + throw PostgresCastingError( + code: code, + columnName: self.columnName, + columnIndex: self.columnIndex, + targetType: T.self, + postgresType: self.dataType, + postgresFormat: self.format, + postgresData: copy, + file: file, + line: line + ) + } + } +} diff --git a/Tests/PostgresNIOTests/New/PostgresCellTests.swift b/Tests/PostgresNIOTests/New/PostgresCellTests.swift new file mode 100644 index 00000000..0693f0b1 --- /dev/null +++ b/Tests/PostgresNIOTests/New/PostgresCellTests.swift @@ -0,0 +1,57 @@ +@testable import PostgresNIO +import XCTest + +final class PostgresCellTests: XCTestCase { + func testDecodingANonOptionalString() { + let cell = PostgresCell( + bytes: ByteBuffer(string: "Hello world"), + dataType: .text, + format: .binary, + columnName: "hello", + columnIndex: 1 + ) + + var result: String? + XCTAssertNoThrow(result = try cell.decode(String.self, context: .forTests())) + XCTAssertEqual(result, "Hello world") + } + + func testDecodingAnOptionalString() { + let cell = PostgresCell( + bytes: nil, + dataType: .text, + format: .binary, + columnName: "hello", + columnIndex: 1 + ) + + var result: String? = "test" + XCTAssertNoThrow(result = try cell.decode(String?.self, context: .forTests())) + XCTAssertNil(result) + } + + func testDecodingFailure() { + let cell = PostgresCell( + bytes: ByteBuffer(string: "Hello world"), + dataType: .text, + format: .binary, + columnName: "hello", + columnIndex: 1 + ) + + XCTAssertThrowsError(try cell.decode(Int?.self, context: .forTests())) { + guard let error = $0 as? PostgresCastingError else { + return XCTFail("Unexpected error") + } + + XCTAssertEqual(error.file, #file) + XCTAssertEqual(error.line, #line - 6) + XCTAssertEqual(error.code, .typeMismatch) + XCTAssertEqual(error.columnName, "hello") + XCTAssertEqual(error.columnIndex, 1) + XCTAssert(error.targetType == Int?.self) + XCTAssertEqual(error.postgresType, .text) + XCTAssertEqual(error.postgresFormat, .binary) + } + } +} From f588870b11edc4766b9c77f2cd67bfd492a7f206 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 18 Feb 2022 11:30:13 +0100 Subject: [PATCH 056/292] Add PSQLRow multi decode (#222) --- .../New/PSQLRow-multi-decode.swift | 1068 +++++++++++++++++ dev/generate-psqlrow-multi-decode.sh | 104 ++ 2 files changed, 1172 insertions(+) create mode 100644 Sources/PostgresNIO/New/PSQLRow-multi-decode.swift create mode 100755 dev/generate-psqlrow-multi-decode.sh diff --git a/Sources/PostgresNIO/New/PSQLRow-multi-decode.swift b/Sources/PostgresNIO/New/PSQLRow-multi-decode.swift new file mode 100644 index 00000000..3be5b0c8 --- /dev/null +++ b/Sources/PostgresNIO/New/PSQLRow-multi-decode.swift @@ -0,0 +1,1068 @@ +/// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-psqlrow-multi-decode.sh + +extension PSQLRow { + @inlinable + func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0) { + precondition(self.columns.count >= 1) + let columnIndex = 0 + var cellIterator = self.data.makeIterator() + var cellData = cellIterator.next().unsafelyUnwrapped + var columnIterator = self.columns.makeIterator() + let column = columnIterator.next().unsafelyUnwrapped + let swiftTargetType: Any.Type = T0.self + + do { + let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + return (r0) + } catch let code as PostgresCastingError.Code { + throw PostgresCastingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: swiftTargetType, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + + @inlinable + func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1) { + precondition(self.columns.count >= 2) + var columnIndex = 0 + var cellIterator = self.data.makeIterator() + var cellData = cellIterator.next().unsafelyUnwrapped + var columnIterator = self.columns.makeIterator() + var column = columnIterator.next().unsafelyUnwrapped + var swiftTargetType: Any.Type = T0.self + + do { + let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 1 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T1.self + let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + return (r0, r1) + } catch let code as PostgresCastingError.Code { + throw PostgresCastingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: swiftTargetType, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + + @inlinable + func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2) { + precondition(self.columns.count >= 3) + var columnIndex = 0 + var cellIterator = self.data.makeIterator() + var cellData = cellIterator.next().unsafelyUnwrapped + var columnIterator = self.columns.makeIterator() + var column = columnIterator.next().unsafelyUnwrapped + var swiftTargetType: Any.Type = T0.self + + do { + let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 1 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T1.self + let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 2 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T2.self + let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + return (r0, r1, r2) + } catch let code as PostgresCastingError.Code { + throw PostgresCastingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: swiftTargetType, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + + @inlinable + func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3) { + precondition(self.columns.count >= 4) + var columnIndex = 0 + var cellIterator = self.data.makeIterator() + var cellData = cellIterator.next().unsafelyUnwrapped + var columnIterator = self.columns.makeIterator() + var column = columnIterator.next().unsafelyUnwrapped + var swiftTargetType: Any.Type = T0.self + + do { + let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 1 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T1.self + let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 2 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T2.self + let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 3 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T3.self + let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + return (r0, r1, r2, r3) + } catch let code as PostgresCastingError.Code { + throw PostgresCastingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: swiftTargetType, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + + @inlinable + func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { + precondition(self.columns.count >= 5) + var columnIndex = 0 + var cellIterator = self.data.makeIterator() + var cellData = cellIterator.next().unsafelyUnwrapped + var columnIterator = self.columns.makeIterator() + var column = columnIterator.next().unsafelyUnwrapped + var swiftTargetType: Any.Type = T0.self + + do { + let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 1 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T1.self + let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 2 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T2.self + let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 3 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T3.self + let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 4 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T4.self + let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + return (r0, r1, r2, r3, r4) + } catch let code as PostgresCastingError.Code { + throw PostgresCastingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: swiftTargetType, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + + @inlinable + func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { + precondition(self.columns.count >= 6) + var columnIndex = 0 + var cellIterator = self.data.makeIterator() + var cellData = cellIterator.next().unsafelyUnwrapped + var columnIterator = self.columns.makeIterator() + var column = columnIterator.next().unsafelyUnwrapped + var swiftTargetType: Any.Type = T0.self + + do { + let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 1 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T1.self + let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 2 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T2.self + let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 3 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T3.self + let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 4 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T4.self + let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 5 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T5.self + let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + return (r0, r1, r2, r3, r4, r5) + } catch let code as PostgresCastingError.Code { + throw PostgresCastingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: swiftTargetType, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + + @inlinable + func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { + precondition(self.columns.count >= 7) + var columnIndex = 0 + var cellIterator = self.data.makeIterator() + var cellData = cellIterator.next().unsafelyUnwrapped + var columnIterator = self.columns.makeIterator() + var column = columnIterator.next().unsafelyUnwrapped + var swiftTargetType: Any.Type = T0.self + + do { + let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 1 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T1.self + let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 2 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T2.self + let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 3 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T3.self + let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 4 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T4.self + let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 5 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T5.self + let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 6 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T6.self + let r6 = try T6.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + return (r0, r1, r2, r3, r4, r5, r6) + } catch let code as PostgresCastingError.Code { + throw PostgresCastingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: swiftTargetType, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + + @inlinable + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { + precondition(self.columns.count >= 8) + var columnIndex = 0 + var cellIterator = self.data.makeIterator() + var cellData = cellIterator.next().unsafelyUnwrapped + var columnIterator = self.columns.makeIterator() + var column = columnIterator.next().unsafelyUnwrapped + var swiftTargetType: Any.Type = T0.self + + do { + let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 1 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T1.self + let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 2 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T2.self + let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 3 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T3.self + let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 4 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T4.self + let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 5 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T5.self + let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 6 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T6.self + let r6 = try T6.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 7 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T7.self + let r7 = try T7.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + return (r0, r1, r2, r3, r4, r5, r6, r7) + } catch let code as PostgresCastingError.Code { + throw PostgresCastingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: swiftTargetType, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + + @inlinable + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { + precondition(self.columns.count >= 9) + var columnIndex = 0 + var cellIterator = self.data.makeIterator() + var cellData = cellIterator.next().unsafelyUnwrapped + var columnIterator = self.columns.makeIterator() + var column = columnIterator.next().unsafelyUnwrapped + var swiftTargetType: Any.Type = T0.self + + do { + let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 1 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T1.self + let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 2 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T2.self + let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 3 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T3.self + let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 4 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T4.self + let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 5 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T5.self + let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 6 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T6.self + let r6 = try T6.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 7 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T7.self + let r7 = try T7.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 8 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T8.self + let r8 = try T8.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + return (r0, r1, r2, r3, r4, r5, r6, r7, r8) + } catch let code as PostgresCastingError.Code { + throw PostgresCastingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: swiftTargetType, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + + @inlinable + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { + precondition(self.columns.count >= 10) + var columnIndex = 0 + var cellIterator = self.data.makeIterator() + var cellData = cellIterator.next().unsafelyUnwrapped + var columnIterator = self.columns.makeIterator() + var column = columnIterator.next().unsafelyUnwrapped + var swiftTargetType: Any.Type = T0.self + + do { + let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 1 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T1.self + let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 2 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T2.self + let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 3 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T3.self + let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 4 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T4.self + let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 5 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T5.self + let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 6 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T6.self + let r6 = try T6.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 7 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T7.self + let r7 = try T7.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 8 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T8.self + let r8 = try T8.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 9 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T9.self + let r9 = try T9.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9) + } catch let code as PostgresCastingError.Code { + throw PostgresCastingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: swiftTargetType, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + + @inlinable + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { + precondition(self.columns.count >= 11) + var columnIndex = 0 + var cellIterator = self.data.makeIterator() + var cellData = cellIterator.next().unsafelyUnwrapped + var columnIterator = self.columns.makeIterator() + var column = columnIterator.next().unsafelyUnwrapped + var swiftTargetType: Any.Type = T0.self + + do { + let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 1 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T1.self + let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 2 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T2.self + let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 3 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T3.self + let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 4 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T4.self + let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 5 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T5.self + let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 6 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T6.self + let r6 = try T6.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 7 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T7.self + let r7 = try T7.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 8 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T8.self + let r8 = try T8.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 9 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T9.self + let r9 = try T9.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 10 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T10.self + let r10 = try T10.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10) + } catch let code as PostgresCastingError.Code { + throw PostgresCastingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: swiftTargetType, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + + @inlinable + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { + precondition(self.columns.count >= 12) + var columnIndex = 0 + var cellIterator = self.data.makeIterator() + var cellData = cellIterator.next().unsafelyUnwrapped + var columnIterator = self.columns.makeIterator() + var column = columnIterator.next().unsafelyUnwrapped + var swiftTargetType: Any.Type = T0.self + + do { + let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 1 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T1.self + let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 2 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T2.self + let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 3 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T3.self + let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 4 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T4.self + let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 5 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T5.self + let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 6 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T6.self + let r6 = try T6.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 7 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T7.self + let r7 = try T7.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 8 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T8.self + let r8 = try T8.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 9 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T9.self + let r9 = try T9.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 10 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T10.self + let r10 = try T10.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 11 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T11.self + let r11 = try T11.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11) + } catch let code as PostgresCastingError.Code { + throw PostgresCastingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: swiftTargetType, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + + @inlinable + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { + precondition(self.columns.count >= 13) + var columnIndex = 0 + var cellIterator = self.data.makeIterator() + var cellData = cellIterator.next().unsafelyUnwrapped + var columnIterator = self.columns.makeIterator() + var column = columnIterator.next().unsafelyUnwrapped + var swiftTargetType: Any.Type = T0.self + + do { + let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 1 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T1.self + let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 2 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T2.self + let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 3 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T3.self + let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 4 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T4.self + let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 5 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T5.self + let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 6 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T6.self + let r6 = try T6.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 7 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T7.self + let r7 = try T7.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 8 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T8.self + let r8 = try T8.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 9 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T9.self + let r9 = try T9.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 10 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T10.self + let r10 = try T10.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 11 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T11.self + let r11 = try T11.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 12 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T12.self + let r12 = try T12.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12) + } catch let code as PostgresCastingError.Code { + throw PostgresCastingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: swiftTargetType, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + + @inlinable + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { + precondition(self.columns.count >= 14) + var columnIndex = 0 + var cellIterator = self.data.makeIterator() + var cellData = cellIterator.next().unsafelyUnwrapped + var columnIterator = self.columns.makeIterator() + var column = columnIterator.next().unsafelyUnwrapped + var swiftTargetType: Any.Type = T0.self + + do { + let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 1 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T1.self + let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 2 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T2.self + let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 3 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T3.self + let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 4 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T4.self + let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 5 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T5.self + let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 6 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T6.self + let r6 = try T6.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 7 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T7.self + let r7 = try T7.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 8 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T8.self + let r8 = try T8.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 9 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T9.self + let r9 = try T9.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 10 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T10.self + let r10 = try T10.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 11 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T11.self + let r11 = try T11.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 12 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T12.self + let r12 = try T12.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 13 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T13.self + let r13 = try T13.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13) + } catch let code as PostgresCastingError.Code { + throw PostgresCastingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: swiftTargetType, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + + @inlinable + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { + precondition(self.columns.count >= 15) + var columnIndex = 0 + var cellIterator = self.data.makeIterator() + var cellData = cellIterator.next().unsafelyUnwrapped + var columnIterator = self.columns.makeIterator() + var column = columnIterator.next().unsafelyUnwrapped + var swiftTargetType: Any.Type = T0.self + + do { + let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 1 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T1.self + let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 2 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T2.self + let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 3 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T3.self + let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 4 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T4.self + let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 5 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T5.self + let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 6 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T6.self + let r6 = try T6.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 7 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T7.self + let r7 = try T7.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 8 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T8.self + let r8 = try T8.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 9 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T9.self + let r9 = try T9.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 10 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T10.self + let r10 = try T10.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 11 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T11.self + let r11 = try T11.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 12 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T12.self + let r12 = try T12.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 13 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T13.self + let r13 = try T13.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + columnIndex = 14 + cellData = cellIterator.next().unsafelyUnwrapped + column = columnIterator.next().unsafelyUnwrapped + swiftTargetType = T14.self + let r14 = try T14.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14) + } catch let code as PostgresCastingError.Code { + throw PostgresCastingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: swiftTargetType, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } +} diff --git a/dev/generate-psqlrow-multi-decode.sh b/dev/generate-psqlrow-multi-decode.sh new file mode 100755 index 00000000..e58b17b2 --- /dev/null +++ b/dev/generate-psqlrow-multi-decode.sh @@ -0,0 +1,104 @@ +#!/bin/bash + +set -eu + +here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +function gen() { + how_many=$1 + + if [[ $how_many -ne 1 ]] ; then + echo "" + fi + + echo " @inlinable" + #echo " @_alwaysEmitIntoClient" + echo -n " func decode(_: (T0" + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", T$(($n))" + done + echo -n ").Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws" + + echo -n " -> (T0" + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", T$(($n))" + done + echo ") {" + echo " precondition(self.columns.count >= $how_many)" + #echo " var columnIndex = 0" + if [[ $how_many -eq 1 ]] ; then + echo " let columnIndex = 0" + echo " var cellIterator = self.data.makeIterator()" + echo " var cellData = cellIterator.next().unsafelyUnwrapped" + echo " var columnIterator = self.columns.makeIterator()" + echo " let column = columnIterator.next().unsafelyUnwrapped" + echo " let swiftTargetType: Any.Type = T0.self" + else + echo " var columnIndex = 0" + echo " var cellIterator = self.data.makeIterator()" + echo " var cellData = cellIterator.next().unsafelyUnwrapped" + echo " var columnIterator = self.columns.makeIterator()" + echo " var column = columnIterator.next().unsafelyUnwrapped" + echo " var swiftTargetType: Any.Type = T0.self" + fi + + echo + echo " do {" + echo " let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context)" + echo + for ((n = 1; n<$how_many; n +=1)); do + echo " columnIndex = $n" + echo " cellData = cellIterator.next().unsafelyUnwrapped" + echo " column = columnIterator.next().unsafelyUnwrapped" + echo " swiftTargetType = T$n.self" + echo " let r$n = try T$n.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context)" + echo + done + + echo -n " return (r0" + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", r$(($n))" + done + echo ")" + echo " } catch let code as PostgresCastingError.Code {" + echo " throw PostgresCastingError(" + echo " code: code," + echo " columnName: column.name," + echo " columnIndex: columnIndex," + echo " targetType: swiftTargetType," + echo " postgresType: column.dataType," + echo " postgresFormat: column.format," + echo " postgresData: cellData," + echo " file: file," + echo " line: line" + echo " )" + echo " }" + echo " }" +} + +grep -q "ByteBuffer" "${BASH_SOURCE[0]}" || { + echo >&2 "ERROR: ${BASH_SOURCE[0]}: file or directory not found (this should be this script)" + exit 1 +} + +{ +cat <<"EOF" +/// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-psqlrow-multi-decode.sh +EOF +echo + +echo "extension PSQLRow {" + +# note: +# - widening the inverval below (eg. going from {1..15} to {1..25}) is Semver minor +# - narrowing the interval below is SemVer _MAJOR_! +for n in {1..15}; do + gen "$n" +done +echo "}" +} > "$here/../Sources/PostgresNIO/New/PSQLRow-multi-decode.swift" From dd5b17c6e9404f8d630c36f358cfe637cb486278 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 21 Feb 2022 07:42:30 +0100 Subject: [PATCH 057/292] Add PostgresQuery (#223) --- .../PostgresConnection+Connect.swift | 9 +- .../PostgresConnection+Database.swift | 31 ++++- .../PostgresNIO/Data/PostgresDataType.swift | 2 +- .../New/BufferedMessageEncoder.swift | 10 +- .../ConnectionStateMachine.swift | 18 +-- .../ExtendedQueryStateMachine.swift | 12 +- .../New/Data/Array+PSQLCodable.swift | 5 +- .../New/Data/Bool+PSQLCodable.swift | 5 +- .../New/Data/Bytes+PSQLCodable.swift | 15 ++- .../New/Data/Date+PSQLCodable.swift | 7 +- .../New/Data/Decimal+PSQLCodable.swift | 5 +- .../New/Data/Float+PSQLCodable.swift | 10 +- .../New/Data/Int+PSQLCodable.swift | 25 +++- .../New/Data/JSON+PSQLCodable.swift | 5 +- .../New/Data/Optional+PSQLCodable.swift | 10 +- .../Data/RawRepresentable+PSQLCodable.swift | 5 +- .../New/Data/String+PSQLCodable.swift | 5 +- .../New/Data/UUID+PSQLCodable.swift | 5 +- Sources/PostgresNIO/New/Messages/Bind.swift | 25 ++-- .../PostgresNIO/New/PSQLChannelHandler.swift | 103 +++++++-------- Sources/PostgresNIO/New/PSQLCodable.swift | 22 +++- Sources/PostgresNIO/New/PSQLConnection.swift | 44 ++----- .../PostgresNIO/New/PSQLFrontendMessage.swift | 2 +- .../New/PSQLFrontendMessageEncoder.swift | 10 +- Sources/PostgresNIO/New/PSQLTask.swift | 21 +-- Sources/PostgresNIO/New/PostgresQuery.swift | 120 ++++++++++++++++++ Sources/PostgresNIO/Postgres+PSQLCompat.swift | 12 +- .../PSQLIntegrationTests.swift | 20 +-- .../ConnectionStateMachineTests.swift | 9 +- .../ExtendedQueryStateMachineTests.swift | 22 ++-- .../New/Data/JSON+PSQLCodableTests.swift | 2 +- .../ConnectionAction+TestUtils.swift | 39 +----- .../New/Extensions/PSQLCoding+TestUtils.swift | 10 +- .../PSQLFrontendMessage+Equatable.swift | 82 ------------ .../New/Messages/BindTests.swift | 9 +- .../New/Messages/CancelTests.swift | 4 +- .../New/Messages/CloseTests.swift | 8 +- .../New/Messages/DescribeTests.swift | 8 +- .../New/Messages/ExecuteTests.swift | 4 +- .../New/Messages/ParseTests.swift | 4 +- .../New/Messages/PasswordTests.swift | 4 +- .../Messages/SASLInitialResponseTests.swift | 8 +- .../New/Messages/SASLResponseTests.swift | 8 +- .../New/Messages/SSLRequestTests.swift | 4 +- .../New/Messages/StartupTests.swift | 4 +- .../New/PSQLChannelHandlerTests.swift | 4 +- .../New/PSQLFrontendMessageTests.swift | 12 +- .../New/PSQLRowSequenceTests.swift | 24 ++-- .../New/PSQLRowStreamTests.swift | 14 +- .../New/PostgresQueryTests.swift | 82 ++++++++++++ 50 files changed, 525 insertions(+), 403 deletions(-) create mode 100644 Sources/PostgresNIO/New/PostgresQuery.swift delete mode 100644 Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessage+Equatable.swift create mode 100644 Tests/PostgresNIOTests/New/PostgresQueryTests.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift index 388cdbc4..bcedd2fb 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift @@ -10,16 +10,11 @@ extension PostgresConnection { logger: Logger = .init(label: "codes.vapor.postgres"), on eventLoop: EventLoop ) -> EventLoopFuture { - - let coders = PSQLConnection.Configuration.Coders( - jsonEncoder: _defaultJSONEncoder - ) - let configuration = PSQLConnection.Configuration( connection: .resolved(address: socketAddress, serverName: serverHostname), authentication: nil, - tlsConfiguration: tlsConfiguration, - coders: coders) + tlsConfiguration: tlsConfiguration + ) return PSQLConnection.connect( configuration: configuration, diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift index 68e6c96c..8b82c1b4 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift @@ -15,7 +15,14 @@ extension PostgresConnection: PostgresDatabase { switch command { case .query(let query, let binds, let onMetadata, let onRow): - resultFuture = self.underlying.query(query, binds, logger: logger).flatMap { stream in + var psqlQuery = PostgresQuery(unsafeSQL: query, binds: .init(capacity: binds.count)) + binds.forEach { + // We can bang the try here as encoding PostgresData does not throw. The throw + // is just an option for the protocol. + try! psqlQuery.appendBinding($0, context: .default) + } + + resultFuture = self.underlying.query(psqlQuery, logger: logger).flatMap { stream in let fields = stream.rowDescription.map { column in PostgresMessage.RowDescription.Field( name: column.name, @@ -34,7 +41,14 @@ extension PostgresConnection: PostgresDatabase { } } case .queryAll(let query, let binds, let onResult): - resultFuture = self.underlying.query(query, binds, logger: logger).flatMap { rows in + var psqlQuery = PostgresQuery(unsafeSQL: query, binds: .init(capacity: binds.count)) + binds.forEach { + // We can bang the try here as encoding PostgresData does not throw. The throw + // is just an option for the protocol. + try! psqlQuery.appendBinding($0, context: .default) + } + + resultFuture = self.underlying.query(psqlQuery, logger: logger).flatMap { rows in let fields = rows.rowDescription.map { column in PostgresMessage.RowDescription.Field( name: column.name, @@ -65,7 +79,18 @@ extension PostgresConnection: PostgresDatabase { request.prepared = PreparedQuery(underlying: $0, database: self) } case .executePreparedStatement(let preparedQuery, let binds, let onRow): - resultFuture = self.underlying.execute(preparedQuery.underlying, binds, logger: logger).flatMap { rows in + var bindings = PostgresBindings() + binds.forEach { data in + try! bindings.append(data, context: .default) + } + + let statement = PSQLExecuteStatement( + name: preparedQuery.underlying.name, + binds: bindings, + rowDescription: preparedQuery.underlying.rowDescription + ) + + resultFuture = self.underlying.execute(statement, logger: logger).flatMap { rows in guard let lookupTable = preparedQuery.lookupTable else { return self.eventLoop.makeSucceededFuture(()) } diff --git a/Sources/PostgresNIO/Data/PostgresDataType.swift b/Sources/PostgresNIO/Data/PostgresDataType.swift index 1652048b..3daa85c5 100644 --- a/Sources/PostgresNIO/Data/PostgresDataType.swift +++ b/Sources/PostgresNIO/Data/PostgresDataType.swift @@ -26,7 +26,7 @@ public typealias PostgresFormatCode = PostgresFormat /// The data type's raw object ID. /// Use `select * from pg_type where oid = ;` to lookup more information. -public struct PostgresDataType: RawRepresentable, Equatable, CustomStringConvertible { +public struct PostgresDataType: RawRepresentable, Hashable, CustomStringConvertible { /// `0` public static let null = PostgresDataType(0) /// `16` diff --git a/Sources/PostgresNIO/New/BufferedMessageEncoder.swift b/Sources/PostgresNIO/New/BufferedMessageEncoder.swift index 9c02871e..9de1443d 100644 --- a/Sources/PostgresNIO/New/BufferedMessageEncoder.swift +++ b/Sources/PostgresNIO/New/BufferedMessageEncoder.swift @@ -1,6 +1,6 @@ import NIOCore -struct BufferedMessageEncoder { +struct BufferedMessageEncoder { private enum State { case flushed case writable @@ -8,14 +8,14 @@ struct BufferedMessageEncoder { private var buffer: ByteBuffer private var state: State = .writable - private var encoder: Encoder + private var encoder: PSQLFrontendMessageEncoder - init(buffer: ByteBuffer, encoder: Encoder) { + init(buffer: ByteBuffer, encoder: PSQLFrontendMessageEncoder) { self.buffer = buffer self.encoder = encoder } - mutating func encode(_ message: Encoder.OutboundIn) throws { + mutating func encode(_ message: PSQLFrontendMessage) { switch self.state { case .flushed: self.state = .writable @@ -25,7 +25,7 @@ struct BufferedMessageEncoder { break } - try self.encoder.encode(data: message, out: &self.buffer) + self.encoder.encode(data: message, out: &self.buffer) } mutating func flush() -> ByteBuffer? { diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 27dd40dc..36bcdf39 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -84,8 +84,8 @@ struct ConnectionStateMachine { // Connection Actions // --- general actions - case sendParseDescribeBindExecuteSync(query: String, binds: [PSQLEncodable]) - case sendBindExecuteSync(statementName: String, binds: [PSQLEncodable]) + case sendParseDescribeBindExecuteSync(PostgresQuery) + case sendBindExecuteSync(PSQLExecuteStatement) case failQuery(ExtendedQueryContext, with: PSQLError, cleanupContext: CleanUpContext?) case succeedQuery(ExtendedQueryContext, columns: [RowDescription.Column]) case succeedQueryNoRowsComming(ExtendedQueryContext, commandTag: String) @@ -1050,11 +1050,11 @@ extension ConnectionStateMachine { } return false - case .decoding(_): + case .decoding: return true - case .unexpectedBackendMessage(_): + case .unexpectedBackendMessage: return true - case .unsupportedAuthMechanism(_): + case .unsupportedAuthMechanism: return true case .authMechanismRequiresPassword: return true @@ -1106,10 +1106,10 @@ extension ConnectionStateMachine { extension ConnectionStateMachine { mutating func modify(with action: ExtendedQueryStateMachine.Action) -> ConnectionStateMachine.ConnectionAction { switch action { - case .sendParseDescribeBindExecuteSync(let query, let binds): - return .sendParseDescribeBindExecuteSync(query: query, binds: binds) - case .sendBindExecuteSync(let statementName, let binds): - return .sendBindExecuteSync(statementName: statementName, binds: binds) + case .sendParseDescribeBindExecuteSync(let query): + return .sendParseDescribeBindExecuteSync(query) + case .sendBindExecuteSync(let executeStatement): + return .sendBindExecuteSync(executeStatement) case .failQuery(let requestContext, with: let error): let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) return .failQuery(requestContext, with: error, cleanupContext: cleanupContext) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 67fe219f..c778477a 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -23,8 +23,8 @@ struct ExtendedQueryStateMachine { } enum Action { - case sendParseDescribeBindExecuteSync(query: String, binds: [PSQLEncodable]) - case sendBindExecuteSync(statementName: String, binds: [PSQLEncodable]) + case sendParseDescribeBindExecuteSync(PostgresQuery) + case sendBindExecuteSync(PSQLExecuteStatement) // --- general actions case failQuery(ExtendedQueryContext, with: PSQLError) @@ -56,18 +56,18 @@ struct ExtendedQueryStateMachine { case .unnamed(let query): return self.avoidingStateMachineCoW { state -> Action in state = .parseDescribeBindExecuteSyncSent(queryContext) - return .sendParseDescribeBindExecuteSync(query: query, binds: queryContext.bind) + return .sendParseDescribeBindExecuteSync(query) } - case .preparedStatement(let name, let rowDescription): + case .preparedStatement(let prepared): return self.avoidingStateMachineCoW { state -> Action in - switch rowDescription { + switch prepared.rowDescription { case .some(let rowDescription): state = .rowDescriptionReceived(queryContext, rowDescription.columns) case .none: state = .noDataMessageReceived(queryContext) } - return .sendBindExecuteSync(statementName: name, binds: queryContext.bind) + return .sendBindExecuteSync(prepared) } } } diff --git a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift index d74717f7..14b92050 100644 --- a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift @@ -76,7 +76,10 @@ extension Array: PSQLEncodable where Element: PSQLArrayElement { .binary } - func encode(into buffer: inout ByteBuffer, context: PSQLEncodingContext) throws { + func encode( + into buffer: inout ByteBuffer, + context: PSQLEncodingContext + ) throws { // 0 if empty, 1 if not buffer.writeInteger(self.isEmpty ? 0 : 1, as: UInt32.self) // b diff --git a/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift index d9896efe..ce0350a2 100644 --- a/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift @@ -49,7 +49,10 @@ extension Bool: PSQLCodable { } } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) { byteBuffer.writeInteger(self ? 1 : 0, as: UInt8.self) } } diff --git a/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift index 8c5e96a5..d7e0e804 100644 --- a/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift @@ -11,7 +11,10 @@ extension PSQLEncodable where Self: Sequence, Self.Element == UInt8 { .binary } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) { byteBuffer.writeBytes(self) } } @@ -25,7 +28,10 @@ extension ByteBuffer: PSQLCodable { .binary } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) { var copyOfSelf = self // dirty hack byteBuffer.writeBuffer(©OfSelf) } @@ -49,7 +55,10 @@ extension Data: PSQLCodable { .binary } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) { byteBuffer.writeBytes(self) } diff --git a/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift index 05491a61..d8d48915 100644 --- a/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift @@ -34,9 +34,12 @@ extension Date: PSQLCodable { } } - func encode(into buffer: inout ByteBuffer, context: PSQLEncodingContext) { + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) { let seconds = self.timeIntervalSince(Self._psqlDateStart) * Double(Self._microsecondsPerSecond) - buffer.writeInteger(Int64(seconds)) + byteBuffer.writeInteger(Int64(seconds)) } // MARK: Private Constants diff --git a/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift index 22c4785d..aa1569cc 100644 --- a/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift @@ -32,7 +32,10 @@ extension Decimal: PSQLCodable { } } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) { let numeric = PostgresNumeric(decimal: self) byteBuffer.writeInteger(numeric.ndigits) byteBuffer.writeInteger(numeric.weight) diff --git a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift index a3463c8e..fd5abfb2 100644 --- a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift @@ -36,7 +36,10 @@ extension Float: PSQLCodable { } } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) { byteBuffer.psqlWriteFloat(self) } } @@ -77,7 +80,10 @@ extension Double: PSQLCodable { } } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) { byteBuffer.psqlWriteDouble(self) } } diff --git a/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift index 49284a8a..ca373b78 100644 --- a/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift @@ -27,7 +27,10 @@ extension UInt8: PSQLCodable { } } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) { byteBuffer.writeInteger(self, as: UInt8.self) } } @@ -64,7 +67,10 @@ extension Int16: PSQLCodable { } } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) { byteBuffer.writeInteger(self, as: Int16.self) } } @@ -105,7 +111,10 @@ extension Int32: PSQLCodable { } } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) { byteBuffer.writeInteger(self, as: Int32.self) } } @@ -151,7 +160,10 @@ extension Int64: PSQLCodable { } } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) { byteBuffer.writeInteger(self, as: Int64.self) } } @@ -204,7 +216,10 @@ extension Int: PSQLCodable { } } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) { byteBuffer.writeInteger(self, as: Int.self) } } diff --git a/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift index 47aff5bf..972f11e7 100644 --- a/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift @@ -33,7 +33,10 @@ extension PSQLCodable where Self: Codable { } } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws { + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) throws { byteBuffer.writeInteger(JSONBVersionByte) try context.jsonEncoder.encode(self, into: &byteBuffer) } diff --git a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift index fef7d9d2..79ba08af 100644 --- a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift @@ -44,11 +44,17 @@ extension Optional: PSQLEncodable where Wrapped: PSQLEncodable { } } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws { + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) { preconditionFailure("Should never be hit, since `encodeRaw` is implemented.") } - func encodeRaw(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws { + func encodeRaw( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) throws { switch self { case .none: byteBuffer.writeInteger(-1, as: Int32.self) diff --git a/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift index 2036fef6..3a05a848 100644 --- a/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift @@ -23,7 +23,10 @@ extension PSQLCodable where Self: RawRepresentable, RawValue: PSQLCodable { return selfValue } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws { + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) throws { try rawValue.encode(into: &byteBuffer, context: context) } } diff --git a/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift index 66f4a400..481296cc 100644 --- a/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift @@ -10,7 +10,10 @@ extension String: PSQLCodable { .binary } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) { byteBuffer.writeString(self) } diff --git a/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift index b258473b..bf5265f3 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift @@ -12,7 +12,10 @@ extension UUID: PSQLCodable { .binary } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) { let uuid = self.uuid byteBuffer.writeBytes([ uuid.0, uuid.1, uuid.2, uuid.3, diff --git a/Sources/PostgresNIO/New/Messages/Bind.swift b/Sources/PostgresNIO/New/Messages/Bind.swift index eea976c9..74868b4c 100644 --- a/Sources/PostgresNIO/New/Messages/Bind.swift +++ b/Sources/PostgresNIO/New/Messages/Bind.swift @@ -2,17 +2,17 @@ import NIOCore extension PSQLFrontendMessage { - struct Bind { + struct Bind: PSQLMessagePayloadEncodable, Equatable { /// The name of the destination portal (an empty string selects the unnamed portal). var portalName: String /// The name of the source prepared statement (an empty string selects the unnamed prepared statement). var preparedStatementName: String - + /// The number of parameter values that follow (possibly zero). This must match the number of parameters needed by the query. - var parameters: [PSQLEncodable] + var bind: PostgresBindings - func encode(into buffer: inout ByteBuffer, using jsonEncoder: PostgresJSONEncoder) throws { + func encode(into buffer: inout ByteBuffer) { buffer.writeNullTerminatedString(self.portalName) buffer.writeNullTerminatedString(self.preparedStatementName) @@ -20,20 +20,17 @@ extension PSQLFrontendMessage { // zero to indicate that there are no parameters or that the parameters all use the // default format (text); or one, in which case the specified format code is applied // to all parameters; or it can equal the actual number of parameters. - buffer.writeInteger(Int16(self.parameters.count)) + buffer.writeInteger(Int16(self.bind.count)) // The parameter format codes. Each must presently be zero (text) or one (binary). - self.parameters.forEach { - buffer.writeInteger($0.psqlFormat.rawValue) + self.bind.metadata.forEach { + buffer.writeInteger($0.format.rawValue) } - buffer.writeInteger(Int16(self.parameters.count)) - - let context = PSQLEncodingContext(jsonEncoder: jsonEncoder) - - try self.parameters.forEach { parameter in - try parameter.encodeRaw(into: &buffer, context: context) - } + buffer.writeInteger(Int16(self.bind.count)) + + var parametersCopy = self.bind.bytes + buffer.writeBuffer(¶metersCopy) // The number of result-column format codes that follow (denoted R below). This can be // zero to indicate that there are no result columns or that the result columns should diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PSQLChannelHandler.swift index c1f3c016..575bf02c 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PSQLChannelHandler.swift @@ -25,7 +25,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { private var handlerContext: ChannelHandlerContext! private var rowStream: PSQLRowStream? private var decoder: NIOSingleStepByteToMessageProcessor - private var encoder: BufferedMessageEncoder! + private var encoder: BufferedMessageEncoder! private let configuration: PSQLConnection.Configuration private let configureSSLCallback: ((Channel) throws -> Void)? @@ -64,7 +64,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { self.handlerContext = context self.encoder = BufferedMessageEncoder( buffer: context.channel.allocator.buffer(capacity: 256), - encoder: PSQLFrontendMessageEncoder(jsonEncoder: self.configuration.coders.jsonEncoder) + encoder: PSQLFrontendMessageEncoder() ) if context.channel.isActive { @@ -228,18 +228,18 @@ final class PSQLChannelHandler: ChannelDuplexHandler { case .wait: break case .sendStartupMessage(let authContext): - try! self.encoder.encode(.startup(.versionThree(parameters: authContext.toStartupParameters()))) + self.encoder.encode(.startup(.versionThree(parameters: authContext.toStartupParameters()))) context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) case .sendSSLRequest: - try! self.encoder.encode(.sslRequest(.init())) + self.encoder.encode(.sslRequest(.init())) context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) case .sendPasswordMessage(let mode, let authContext): self.sendPasswordMessage(mode: mode, authContext: authContext, context: context) case .sendSaslInitialResponse(let name, let initialResponse): - try! self.encoder.encode(.saslInitialResponse(.init(saslMechanism: name, initialData: initialResponse))) + self.encoder.encode(.saslInitialResponse(.init(saslMechanism: name, initialData: initialResponse))) context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) case .sendSaslResponse(let bytes): - try! self.encoder.encode(.saslResponse(.init(data: bytes))) + self.encoder.encode(.saslResponse(.init(data: bytes))) context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) case .closeConnectionAndCleanup(let cleanupContext): self.closeConnectionAndCleanup(cleanupContext, context: context) @@ -247,10 +247,10 @@ final class PSQLChannelHandler: ChannelDuplexHandler { context.fireChannelInactive() case .sendParseDescribeSync(let name, let query): self.sendParseDecribeAndSyncMessage(statementName: name, query: query, context: context) - case .sendBindExecuteSync(let statementName, let binds): - self.sendBindExecuteAndSyncMessage(statementName: statementName, binds: binds, context: context) - case .sendParseDescribeBindExecuteSync(let query, let binds): - self.sendParseDescribeBindExecuteAndSyncMessage(query: query, binds: binds, context: context) + case .sendBindExecuteSync(let executeStatement): + self.sendBindExecuteAndSyncMessage(executeStatement: executeStatement, context: context) + case .sendParseDescribeBindExecuteSync(let query): + self.sendParseDescribeBindExecuteAndSyncMessage(query: query, context: context) case .succeedQuery(let queryContext, columns: let columns): self.succeedQueryWithRowStream(queryContext, columns: columns, context: context) case .succeedQueryNoRowsComming(let queryContext, let commandTag): @@ -303,7 +303,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { // The normal, graceful termination procedure is that the frontend sends a Terminate // message and immediately closes the connection. On receipt of this message, the // backend closes the connection and terminates. - try! self.encoder.encode(.terminate) + self.encoder.encode(.terminate) context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) } context.close(mode: .all, promise: promise) @@ -368,11 +368,11 @@ final class PSQLChannelHandler: ChannelDuplexHandler { hash2.append(salt.3) let hash = Insecure.MD5.hash(data: hash2).md5PrefixHexdigest() - try! self.encoder.encode(.password(.init(value: hash))) + self.encoder.encode(.password(.init(value: hash))) context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) case .cleartext: - try! self.encoder.encode(.password(.init(value: authContext.password ?? ""))) + self.encoder.encode(.password(.init(value: authContext.password ?? ""))) context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) } } @@ -380,13 +380,13 @@ final class PSQLChannelHandler: ChannelDuplexHandler { private func sendCloseAndSyncMessage(_ sendClose: CloseTarget, context: ChannelHandlerContext) { switch sendClose { case .preparedStatement(let name): - try! self.encoder.encode(.close(.preparedStatement(name))) - try! self.encoder.encode(.sync) + self.encoder.encode(.close(.preparedStatement(name))) + self.encoder.encode(.sync) context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) case .portal(let name): - try! self.encoder.encode(.close(.portal(name))) - try! self.encoder.encode(.sync) + self.encoder.encode(.close(.portal(name))) + self.encoder.encode(.sync) context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) } } @@ -401,66 +401,49 @@ final class PSQLChannelHandler: ChannelDuplexHandler { preparedStatementName: statementName, query: query, parameters: []) - - - do { - try self.encoder.encode(.parse(parse)) - try self.encoder.encode(.describe(.preparedStatement(statementName))) - try self.encoder.encode(.sync) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) - } catch { - let action = self.state.errorHappened(.channel(underlying: error)) - self.run(action, with: context) - } + + self.encoder.encode(.parse(parse)) + self.encoder.encode(.describe(.preparedStatement(statementName))) + self.encoder.encode(.sync) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) } private func sendBindExecuteAndSyncMessage( - statementName: String, - binds: [PSQLEncodable], - context: ChannelHandlerContext) - { + executeStatement: PSQLExecuteStatement, + context: ChannelHandlerContext + ) { let bind = PSQLFrontendMessage.Bind( portalName: "", - preparedStatementName: statementName, - parameters: binds) - - do { - try self.encoder.encode(.bind(bind)) - try self.encoder.encode(.execute(.init(portalName: ""))) - try self.encoder.encode(.sync) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) - } catch { - let action = self.state.errorHappened(.channel(underlying: error)) - self.run(action, with: context) - } + preparedStatementName: executeStatement.name, + bind: executeStatement.binds) + + self.encoder.encode(.bind(bind)) + self.encoder.encode(.execute(.init(portalName: ""))) + self.encoder.encode(.sync) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) } private func sendParseDescribeBindExecuteAndSyncMessage( - query: String, binds: [PSQLEncodable], + query: PostgresQuery, context: ChannelHandlerContext) { precondition(self.rowStream == nil, "Expected to not have an open stream at this point") let unnamedStatementName = "" let parse = PSQLFrontendMessage.Parse( preparedStatementName: unnamedStatementName, - query: query, - parameters: binds.map { $0.psqlType }) + query: query.sql, + parameters: query.binds.metadata.map(\.dataType)) let bind = PSQLFrontendMessage.Bind( portalName: "", preparedStatementName: unnamedStatementName, - parameters: binds) - - do { - try self.encoder.encode(.parse(parse)) - try self.encoder.encode(.describe(.preparedStatement(""))) - try self.encoder.encode(.bind(bind)) - try self.encoder.encode(.execute(.init(portalName: ""))) - try self.encoder.encode(.sync) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) - } catch { - let action = self.state.errorHappened(.channel(underlying: error)) - self.run(action, with: context) - } + bind: query.binds) + + self.encoder.encode(.parse(parse)) + self.encoder.encode(.describe(.preparedStatement(""))) + self.encoder.encode(.bind(bind)) + self.encoder.encode(.execute(.init(portalName: ""))) + self.encoder.encode(.sync) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) } private func succeedQueryWithRowStream( diff --git a/Sources/PostgresNIO/New/PSQLCodable.swift b/Sources/PostgresNIO/New/PSQLCodable.swift index 4b5dd041..02ed779f 100644 --- a/Sources/PostgresNIO/New/PSQLCodable.swift +++ b/Sources/PostgresNIO/New/PSQLCodable.swift @@ -1,4 +1,5 @@ import NIOCore +import Foundation /// A type that can encode itself to a postgres wire binary representation. protocol PSQLEncodable { @@ -10,12 +11,12 @@ protocol PSQLEncodable { /// Encode the entity into the `byteBuffer` in Postgres binary format, without setting /// the byte count. This method is called from the default `encodeRaw` implementation. - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws + func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws /// Encode the entity into the `byteBuffer` in Postgres binary format including its /// leading byte count. This method has a default implementation and may be overriden /// only for special cases, like `Optional`s. - func encodeRaw(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws + func encodeRaw(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws } /// A type that can decode itself from a postgres wire binary representation. @@ -70,7 +71,10 @@ extension PSQLDecodable { protocol PSQLCodable: PSQLEncodable, PSQLDecodable {} extension PSQLEncodable { - func encodeRaw(into buffer: inout ByteBuffer, context: PSQLEncodingContext) throws { + func encodeRaw( + into buffer: inout ByteBuffer, + context: PSQLEncodingContext + ) throws { // The length of the parameter value, in bytes (this count does not include // itself). Can be zero. let lengthIndex = buffer.writerIndex @@ -85,8 +89,16 @@ extension PSQLEncodable { } } -struct PSQLEncodingContext { - let jsonEncoder: PostgresJSONEncoder +struct PSQLEncodingContext { + let jsonEncoder: JSONEncoder + + init(jsonEncoder: JSONEncoder) { + self.jsonEncoder = jsonEncoder + } +} + +extension PSQLEncodingContext where JSONEncoder == Foundation.JSONEncoder { + static let `default` = PSQLEncodingContext(jsonEncoder: JSONEncoder()) } struct PostgresDecodingContext { diff --git a/Sources/PostgresNIO/New/PSQLConnection.swift b/Sources/PostgresNIO/New/PSQLConnection.swift index 40b42b11..2ebb2bba 100644 --- a/Sources/PostgresNIO/New/PSQLConnection.swift +++ b/Sources/PostgresNIO/New/PSQLConnection.swift @@ -12,18 +12,6 @@ final class PSQLConnection { struct Configuration { - struct Coders { - var jsonEncoder: PostgresJSONEncoder - - init(jsonEncoder: PostgresJSONEncoder) { - self.jsonEncoder = jsonEncoder - } - - static var foundation: Coders { - Coders(jsonEncoder: JSONEncoder()) - } - } - struct Authentication { var username: String var database: String? = nil @@ -47,31 +35,26 @@ final class PSQLConnection { var authentication: Authentication? var tlsConfiguration: TLSConfiguration? - var coders: Coders init(host: String, port: Int = 5432, username: String, database: String? = nil, password: String? = nil, - tlsConfiguration: TLSConfiguration? = nil, - coders: Coders = .foundation) - { + tlsConfiguration: TLSConfiguration? = nil + ) { self.connection = .unresolved(host: host, port: port) self.authentication = Authentication(username: username, password: password, database: database) self.tlsConfiguration = tlsConfiguration - self.coders = coders } init(connection: Connection, authentication: Authentication?, - tlsConfiguration: TLSConfiguration?, - coders: Coders = .foundation) - { + tlsConfiguration: TLSConfiguration? + ) { self.connection = connection self.authentication = authentication self.tlsConfiguration = tlsConfiguration - self.coders = coders } } @@ -116,21 +99,17 @@ final class PSQLConnection { } // MARK: Query - - func query(_ query: String, logger: Logger) -> EventLoopFuture { - self.query(query, [], logger: logger) - } - func query(_ query: String, _ bind: [PSQLEncodable], logger: Logger) -> EventLoopFuture { + func query(_ query: PostgresQuery, logger: Logger) -> EventLoopFuture { var logger = logger logger[postgresMetadataKey: .connectionID] = "\(self.connectionID)" - guard bind.count <= Int(Int16.max) else { + guard query.binds.count <= Int(Int16.max) else { return self.channel.eventLoop.makeFailedFuture(PSQLError.tooManyParameters) } + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let context = ExtendedQueryContext( query: query, - bind: bind, logger: logger, promise: promise) @@ -155,16 +134,13 @@ final class PSQLConnection { } } - func execute(_ preparedStatement: PSQLPreparedStatement, - _ bind: [PSQLEncodable], logger: Logger) -> EventLoopFuture - { - guard bind.count <= Int(Int16.max) else { + func execute(_ executeStatement: PSQLExecuteStatement, logger: Logger) -> EventLoopFuture { + guard executeStatement.binds.count <= Int(Int16.max) else { return self.channel.eventLoop.makeFailedFuture(PSQLError.tooManyParameters) } let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let context = ExtendedQueryContext( - preparedStatement: preparedStatement, - bind: bind, + executeStatement: executeStatement, logger: logger, promise: promise) diff --git a/Sources/PostgresNIO/New/PSQLFrontendMessage.swift b/Sources/PostgresNIO/New/PSQLFrontendMessage.swift index 56e94ff0..1a3cb28d 100644 --- a/Sources/PostgresNIO/New/PSQLFrontendMessage.swift +++ b/Sources/PostgresNIO/New/PSQLFrontendMessage.swift @@ -4,7 +4,7 @@ import NIOCore /// /// All messages are defined in the official Postgres Documentation in the section /// [Frontend/Backend Protocol – Message Formats](https://www.postgresql.org/docs/13/protocol-message-formats.html) -enum PSQLFrontendMessage { +enum PSQLFrontendMessage: Equatable { case bind(Bind) case cancel(Cancel) case close(Close) diff --git a/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift index ea016970..92ffeb07 100644 --- a/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift @@ -2,19 +2,15 @@ struct PSQLFrontendMessageEncoder: MessageToByteEncoder { typealias OutboundIn = PSQLFrontendMessage - let jsonEncoder: PostgresJSONEncoder + init() {} - init(jsonEncoder: PostgresJSONEncoder) { - self.jsonEncoder = jsonEncoder - } - - func encode(data message: PSQLFrontendMessage, out buffer: inout ByteBuffer) throws { + func encode(data message: PSQLFrontendMessage, out buffer: inout ByteBuffer) { switch message { case .bind(let bind): buffer.writeInteger(message.id.rawValue) let startIndex = buffer.writerIndex buffer.writeInteger(Int32(0)) // placeholder for length - try bind.encode(into: &buffer, using: self.jsonEncoder) + bind.encode(into: &buffer) let length = Int32(buffer.writerIndex - startIndex) buffer.setInteger(length, at: startIndex) diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 0f0c6d04..f9ca1232 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -20,40 +20,32 @@ enum PSQLTask { final class ExtendedQueryContext { enum Query { - case unnamed(String) - case preparedStatement(name: String, rowDescription: RowDescription?) + case unnamed(PostgresQuery) + case preparedStatement(PSQLExecuteStatement) } let query: Query - let bind: [PSQLEncodable] let logger: Logger - + let promise: EventLoopPromise - init(query: String, - bind: [PSQLEncodable], + init(query: PostgresQuery, logger: Logger, promise: EventLoopPromise) { self.query = .unnamed(query) - self.bind = bind self.logger = logger self.promise = promise } - init(preparedStatement: PSQLPreparedStatement, - bind: [PSQLEncodable], + init(executeStatement: PSQLExecuteStatement, logger: Logger, promise: EventLoopPromise) { - self.query = .preparedStatement( - name: preparedStatement.name, - rowDescription: preparedStatement.rowDescription) - self.bind = bind + self.query = .preparedStatement(executeStatement) self.logger = logger self.promise = promise } - } final class PrepareStatementContext { @@ -75,7 +67,6 @@ final class PrepareStatementContext { } final class CloseCommandContext { - let target: CloseTarget let logger: Logger let promise: EventLoopPromise diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift new file mode 100644 index 00000000..7c748e83 --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -0,0 +1,120 @@ +struct PostgresQuery: Hashable { + /// The query string + var sql: String + /// The query binds + var binds: PostgresBindings + + init(unsafeSQL sql: String, binds: PostgresBindings = PostgresBindings()) { + self.sql = sql + self.binds = binds + } +} + +extension PostgresQuery: ExpressibleByStringInterpolation { + typealias StringInterpolation = Interpolation + + init(stringInterpolation: Interpolation) { + self.sql = stringInterpolation.sql + self.binds = stringInterpolation.binds + } + + init(stringLiteral value: String) { + self.sql = value + self.binds = PostgresBindings() + } + + mutating func appendBinding( + _ value: Value, + context: PSQLEncodingContext + ) throws { + try self.binds.append(value, context: context) + } +} + +extension PostgresQuery { + struct Interpolation: StringInterpolationProtocol { + typealias StringLiteralType = String + + var sql: String + var binds: PostgresBindings + + init(literalCapacity: Int, interpolationCount: Int) { + self.sql = "" + self.binds = PostgresBindings(capacity: interpolationCount) + } + + mutating func appendLiteral(_ literal: String) { + self.sql.append(contentsOf: literal) + } + + mutating func appendInterpolation(_ value: Value) throws { + try self.binds.append(value, context: .default) + self.sql.append(contentsOf: "$\(self.binds.count)") + } + + mutating func appendInterpolation(_ value: Optional) throws { + try self.binds.append(value, context: .default) + self.sql.append(contentsOf: "$\(self.binds.count)") + } + + mutating func appendInterpolation( + _ value: Value, + context: PSQLEncodingContext + ) throws { + try self.binds.append(value, context: context) + self.sql.append(contentsOf: "$\(self.binds.count)") + } + } +} + +struct PSQLExecuteStatement { + /// The statements name + var name: String + /// The binds + var binds: PostgresBindings + + var rowDescription: RowDescription? +} + +struct PostgresBindings: Hashable { + struct Metadata: Hashable { + var dataType: PostgresDataType + var format: PostgresFormat + + init(dataType: PostgresDataType, format: PostgresFormat) { + self.dataType = dataType + self.format = format + } + + init(value: Value) { + self.init(dataType: value.psqlType, format: value.psqlFormat) + } + } + + var metadata: [Metadata] + var bytes: ByteBuffer + + var count: Int { + self.metadata.count + } + + init() { + self.metadata = [] + self.bytes = ByteBuffer() + } + + init(capacity: Int) { + self.metadata = [] + self.metadata.reserveCapacity(capacity) + self.bytes = ByteBuffer() + self.bytes.reserveCapacity(128 * capacity) + } + + mutating func append( + _ value: Value, + context: PSQLEncodingContext + ) throws { + try value.encodeRaw(into: &self.bytes, context: context) + self.metadata.append(.init(value: value)) + } +} diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index 84bf7e56..acc8d735 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -8,13 +8,19 @@ extension PostgresData: PSQLEncodable { var psqlFormat: PostgresFormat { .binary } - - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws { + + func encode( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) throws { preconditionFailure("Should never be hit, since `encodeRaw` is implemented.") } // encoding - func encodeRaw(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws { + func encodeRaw( + into byteBuffer: inout ByteBuffer, + context: PSQLEncodingContext + ) { switch self.value { case .none: byteBuffer.writeInteger(-1, as: Int32.self) diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index f3d63add..43e1e25a 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -121,7 +121,7 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try conn?.close().wait()) } var stream: PSQLRowStream? - XCTAssertNoThrow(stream = try conn?.query("SELECT $1::TEXT as foo", ["hello"], logger: .psqlTest).wait()) + XCTAssertNoThrow(stream = try conn?.query("SELECT \("hello")::TEXT as foo", logger: .psqlTest).wait()) var rows: [PSQLRow]? XCTAssertNoThrow(rows = try XCTUnwrap(stream).all().wait()) var foo: String? @@ -179,7 +179,7 @@ final class IntegrationTests: XCTestCase { var stream: PSQLRowStream? let array: [Int64] = [1, 2, 3] - XCTAssertNoThrow(stream = try conn?.query("SELECT $1::int8[] as array", [array], logger: .psqlTest).wait()) + XCTAssertNoThrow(stream = try conn?.query("SELECT \(array)::int8[] as array", logger: .psqlTest).wait()) var rows: [PSQLRow]? XCTAssertNoThrow(rows = try stream?.all().wait()) @@ -216,7 +216,7 @@ final class IntegrationTests: XCTestCase { var stream: PSQLRowStream? let doubles: [Double] = [3.14, 42] - XCTAssertNoThrow(stream = try conn?.query("SELECT $1::double precision[] as doubles", [doubles], logger: .psqlTest).wait()) + XCTAssertNoThrow(stream = try conn?.query("SELECT \(doubles)::double precision[] as doubles", logger: .psqlTest).wait()) var rows: [PSQLRow]? XCTAssertNoThrow(rows = try stream?.all().wait()) @@ -263,9 +263,9 @@ final class IntegrationTests: XCTestCase { var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query(""" SELECT - $1::numeric as numeric, - $2::numeric as numeric_negative - """, [Decimal(string: "123456.789123")!, Decimal(string: "-123456.789123")!], logger: .psqlTest).wait()) + \(Decimal(string: "123456.789123")!)::numeric as numeric, + \(Decimal(string: "-123456.789123")!)::numeric as numeric_negative + """, logger: .psqlTest).wait()) var rows: [PSQLRow]? XCTAssertNoThrow(rows = try stream?.all().wait()) @@ -314,8 +314,8 @@ final class IntegrationTests: XCTestCase { do { var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query(""" - select $1::jsonb as jsonb - """, [Object(foo: 1, bar: 2)], logger: .psqlTest).wait()) + select \(Object(foo: 1, bar: 2))::jsonb as jsonb + """, logger: .psqlTest).wait()) var rows: [PSQLRow]? XCTAssertNoThrow(rows = try stream?.all().wait()) @@ -329,8 +329,8 @@ final class IntegrationTests: XCTestCase { do { var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query(""" - select $1::json as json - """, [Object(foo: 1, bar: 2)], logger: .psqlTest).wait()) + select \(Object(foo: 1, bar: 2))::json as json + """, logger: .psqlTest).wait()) var rows: [PSQLRow]? XCTAssertNoThrow(rows = try stream?.all().wait()) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index 79dc27c4..63d40e1a 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -33,7 +33,7 @@ class ConnectionStateMachineTests: XCTestCase { var state = ConnectionStateMachine() XCTAssertEqual(state.connected(requireTLS: true), .sendSSLRequest) XCTAssertEqual(state.sslSupportedReceived(), .establishSSLConnection) - let failError: PSQLError = .failedToAddSSLHandler(underlying: SSLHandlerAddError()) + let failError = PSQLError.failedToAddSSLHandler(underlying: SSLHandlerAddError()) XCTAssertEqual(state.errorHappened(failError), .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: failError, closePromise: nil))) } @@ -42,7 +42,7 @@ class ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.connected(requireTLS: true), .sendSSLRequest) XCTAssertEqual(state.sslUnsupportedReceived(), - .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .sslUnsupported, closePromise: nil))) + .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: PSQLError.sslUnsupported, closePromise: nil))) } func testParameterStatusReceivedAndBackendKeyAfterAuthenticated() { @@ -100,14 +100,14 @@ class ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "standard_conforming_strings", value: "on")), .wait) XCTAssertEqual(state.readyForQueryReceived(.idle), - .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unexpectedBackendMessage(.readyForQuery(.idle)), closePromise: nil))) + .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: PSQLError.unexpectedBackendMessage(.readyForQuery(.idle)), closePromise: nil))) } func testErrorIsIgnoredWhenClosingConnection() { // test ignore unclean shutdown when closing connection var stateIgnoreChannelError = ConnectionStateMachine(.closing) - XCTAssertEqual(stateIgnoreChannelError.errorHappened(.channel(underlying: NIOSSLError.uncleanShutdown)), .wait) + XCTAssertEqual(stateIgnoreChannelError.errorHappened(PSQLError.channel(underlying: NIOSSLError.uncleanShutdown)), .wait) XCTAssertEqual(stateIgnoreChannelError.closed(), .fireChannelInactive) // test ignore any other error when closing connection @@ -129,7 +129,6 @@ class ConnectionStateMachineTests: XCTestCase { var state = ConnectionStateMachine() let extendedQueryContext = ExtendedQueryContext( query: "Select version()", - bind: [], logger: .psqlTest, promise: queryPromise) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index e3a3e515..b5055929 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -12,10 +12,10 @@ class ExtendedQueryStateMachineTests: XCTestCase { let logger = Logger.psqlTest let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. - let query = "DELETE FROM table WHERE id=$0" - let queryContext = ExtendedQueryContext(query: query, bind: [1], logger: logger, promise: promise) + let query: PostgresQuery = try! "DELETE FROM table WHERE id=\(1)" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) - XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query: query, binds: [1])) + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) XCTAssertEqual(state.noDataReceived(), .wait) @@ -28,12 +28,12 @@ class ExtendedQueryStateMachineTests: XCTestCase { var state = ConnectionStateMachine.readyForQuery() let logger = Logger.psqlTest - let queryPromise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) - queryPromise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. - let query = "SELECT version()" - let queryContext = ExtendedQueryContext(query: query, bind: [], logger: logger, promise: queryPromise) + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "SELECT version()" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) - XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query: query, binds: [])) + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) @@ -84,10 +84,10 @@ class ExtendedQueryStateMachineTests: XCTestCase { let logger = Logger.psqlTest let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. - let query = "DELETE FROM table WHERE id=$0" - let queryContext = ExtendedQueryContext(query: query, bind: [1], logger: logger, promise: promise) + let query: PostgresQuery = try! "DELETE FROM table WHERE id=\(1)" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) - XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query: query, binds: [1])) + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) diff --git a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift index 40bf3f34..d17b139c 100644 --- a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift @@ -84,7 +84,7 @@ class JSON_PSQLCodableTests: XCTestCase { let hello = Hello(name: "world") let encoder = TestEncoder() var buffer = ByteBuffer() - XCTAssertNoThrow(try hello.encode(into: &buffer, context: .forTests(jsonEncoder: encoder))) + XCTAssertNoThrow(try hello.encode(into: &buffer, context: .init(jsonEncoder: encoder))) XCTAssertEqual(encoder.encodeHits, 1) } } diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift index c88d112f..6db93101 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -21,43 +21,8 @@ extension ConnectionStateMachine.ConnectionAction: Equatable { return lhs == rhs case (.sendPasswordMessage(let lhsMethod, let lhsAuthContext), sendPasswordMessage(let rhsMethod, let rhsAuthContext)): return lhsMethod == rhsMethod && lhsAuthContext == rhsAuthContext - case (.sendParseDescribeBindExecuteSync(let lquery, let lbinds), sendParseDescribeBindExecuteSync(let rquery, let rbinds)): - guard lquery == rquery else { - return false - } - - guard lbinds.count == rbinds.count else { - return false - } - - var lhsIterator = lbinds.makeIterator() - var rhsIterator = rbinds.makeIterator() - - for _ in 0.. Self { Self(jsonDecoder: JSONDecoder()) } } -extension PSQLEncodingContext { - static func forTests(jsonEncoder: PostgresJSONEncoder = JSONEncoder()) -> Self { +extension PSQLEncodingContext where JSONEncoder == Foundation.JSONEncoder { + static func forTests(jsonEncoder: JSONEncoder = JSONEncoder()) -> Self { Self(jsonEncoder: jsonEncoder) } } diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessage+Equatable.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessage+Equatable.swift deleted file mode 100644 index 36453b7c..00000000 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessage+Equatable.swift +++ /dev/null @@ -1,82 +0,0 @@ -import NIOCore -@testable import PostgresNIO -import class Foundation.JSONEncoder -import class Foundation.JSONDecoder - -extension PSQLFrontendMessage.Bind: Equatable { - public static func ==(lhs: Self, rhs: Self) -> Bool { - guard lhs.preparedStatementName == rhs.preparedStatementName else { - return false - } - - guard lhs.portalName == rhs.portalName else { - return false - } - - guard lhs.parameters.count == rhs.parameters.count else { - return false - } - - var lhsIterator = lhs.parameters.makeIterator() - var rhsIterator = rhs.parameters.makeIterator() - - do { - while let lhs = lhsIterator.next(), let rhs = rhsIterator.next() { - guard lhs.psqlType == rhs.psqlType else { - return false - } - - var lhsBuffer = ByteBuffer() - var rhsBuffer = ByteBuffer() - - try lhs.encode(into: &lhsBuffer, context: .forTests()) - try rhs.encode(into: &rhsBuffer, context: .forTests()) - - guard lhsBuffer == rhsBuffer else { - return false - } - } - - return true - } catch { - return false - } - } -} - -extension PSQLFrontendMessage: Equatable { - public static func ==(lhs: Self, rhs: Self) -> Bool { - switch (lhs, rhs) { - case (.bind(let lhs), .bind(let rhs)): - return lhs == rhs - case (.cancel(let lhs), .cancel(let rhs)): - return lhs == rhs - case (.close(let lhs), .close(let rhs)): - return lhs == rhs - case (.describe(let lhs), .describe(let rhs)): - return lhs == rhs - case (.execute(let lhs), .execute(let rhs)): - return lhs == rhs - case (.flush, .flush): - return true - case (.parse(let lhs), .parse(let rhs)): - return lhs == rhs - case (.password(let lhs), .password(let rhs)): - return lhs == rhs - case (.saslInitialResponse(let lhs), .saslInitialResponse(let rhs)): - return lhs == rhs - case (.saslResponse(let lhs), .saslResponse(let rhs)): - return lhs == rhs - case (.sslRequest(let lhs), .sslRequest(let rhs)): - return lhs == rhs - case (.sync, .sync): - return true - case (.startup(let lhs), .startup(let rhs)): - return lhs == rhs - case (.terminate, .terminate): - return lhs == rhs - default: - return false - } - } -} diff --git a/Tests/PostgresNIOTests/New/Messages/BindTests.swift b/Tests/PostgresNIOTests/New/Messages/BindTests.swift index 7a688d41..285d00ca 100644 --- a/Tests/PostgresNIOTests/New/Messages/BindTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BindTests.swift @@ -5,11 +5,14 @@ import NIOCore class BindTests: XCTestCase { func testEncodeBind() { - let encoder = PSQLFrontendMessageEncoder.forTests + let encoder = PSQLFrontendMessageEncoder() + var bindings = PostgresBindings() + XCTAssertNoThrow(try bindings.append("Hello", context: .default)) + XCTAssertNoThrow(try bindings.append("World", context: .default)) var byteBuffer = ByteBuffer() - let bind = PSQLFrontendMessage.Bind(portalName: "", preparedStatementName: "", parameters: ["Hello", "World"]) + let bind = PSQLFrontendMessage.Bind(portalName: "", preparedStatementName: "", bind: bindings) let message = PSQLFrontendMessage.bind(bind) - XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) + encoder.encode(data: message, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 37) XCTAssertEqual(PSQLFrontendMessage.ID.bind.rawValue, byteBuffer.readInteger(as: UInt8.self)) diff --git a/Tests/PostgresNIOTests/New/Messages/CancelTests.swift b/Tests/PostgresNIOTests/New/Messages/CancelTests.swift index 551e5769..a1626538 100644 --- a/Tests/PostgresNIOTests/New/Messages/CancelTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CancelTests.swift @@ -5,11 +5,11 @@ import NIOCore class CancelTests: XCTestCase { func testEncodeCancel() { - let encoder = PSQLFrontendMessageEncoder.forTests + let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() let cancel = PSQLFrontendMessage.Cancel(processID: 1234, secretKey: 4567) let message = PSQLFrontendMessage.cancel(cancel) - XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) + encoder.encode(data: message, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 16) XCTAssertEqual(16, byteBuffer.readInteger(as: Int32.self)) // payload length diff --git a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift index 4df15896..d9edf95b 100644 --- a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift @@ -5,10 +5,10 @@ import NIOCore class CloseTests: XCTestCase { func testEncodeClosePortal() { - let encoder = PSQLFrontendMessageEncoder.forTests + let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() let message = PSQLFrontendMessage.close(.portal("Hello")) - XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) + encoder.encode(data: message, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 12) XCTAssertEqual(PSQLFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) @@ -19,10 +19,10 @@ class CloseTests: XCTestCase { } func testEncodeCloseUnnamedStatement() { - let encoder = PSQLFrontendMessageEncoder.forTests + let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() let message = PSQLFrontendMessage.close(.preparedStatement("")) - XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) + encoder.encode(data: message, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 7) XCTAssertEqual(PSQLFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) diff --git a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift index 87f7d09b..752a3d0f 100644 --- a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift @@ -5,10 +5,10 @@ import NIOCore class DescribeTests: XCTestCase { func testEncodeDescribePortal() { - let encoder = PSQLFrontendMessageEncoder.forTests + let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() let message = PSQLFrontendMessage.describe(.portal("Hello")) - XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) + encoder.encode(data: message, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 12) XCTAssertEqual(PSQLFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) @@ -19,10 +19,10 @@ class DescribeTests: XCTestCase { } func testEncodeDescribeUnnamedStatement() { - let encoder = PSQLFrontendMessageEncoder.forTests + let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() let message = PSQLFrontendMessage.describe(.preparedStatement("")) - XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) + encoder.encode(data: message, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 7) XCTAssertEqual(PSQLFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) diff --git a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift index 3ce8d63d..9fdf06a7 100644 --- a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift @@ -5,10 +5,10 @@ import NIOCore class ExecuteTests: XCTestCase { func testEncodeExecute() { - let encoder = PSQLFrontendMessageEncoder.forTests + let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() let message = PSQLFrontendMessage.execute(.init(portalName: "", maxNumberOfRows: 0)) - XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) + encoder.encode(data: message, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 10) // 1 (id) + 4 (length) + 1 (empty null terminated string) + 4 (count) XCTAssertEqual(PSQLFrontendMessage.ID.execute.rawValue, byteBuffer.readInteger(as: UInt8.self)) diff --git a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift index edf3f48d..64654153 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift @@ -5,14 +5,14 @@ import NIOCore class ParseTests: XCTestCase { func testEncode() { - let encoder = PSQLFrontendMessageEncoder.forTests + let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() let parse = PSQLFrontendMessage.Parse( preparedStatementName: "test", query: "SELECT version()", parameters: [.bool, .int8, .bytea, .varchar, .text, .uuid, .json, .jsonbArray]) let message = PSQLFrontendMessage.parse(parse) - XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) + encoder.encode(data: message, out: &byteBuffer) let length: Int = 1 + 4 + (parse.preparedStatementName.count + 1) + (parse.query.count + 1) + 2 + parse.parameters.count * 4 diff --git a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift index 73c464f3..492d2723 100644 --- a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift @@ -5,11 +5,11 @@ import NIOCore class PasswordTests: XCTestCase { func testEncodePassword() { - let encoder = PSQLFrontendMessageEncoder.forTests + let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() // md522d085ed8dc3377968dc1c1a40519a2a = "abc123" with salt 1, 2, 3, 4 let message = PSQLFrontendMessage.password(.init(value: "md522d085ed8dc3377968dc1c1a40519a2a")) - XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) + encoder.encode(data: message, out: &byteBuffer) let expectedLength = 41 // 1 (id) + 4 (length) + 35 (string) + 1 (null termination) diff --git a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift index af2459ac..8ad83134 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift @@ -5,12 +5,12 @@ import NIOCore class SASLInitialResponseTests: XCTestCase { func testEncodeWithData() { - let encoder = PSQLFrontendMessageEncoder.forTests + let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() let sasl = PSQLFrontendMessage.SASLInitialResponse( saslMechanism: "hello", initialData: [0, 1, 2, 3, 4, 5, 6, 7]) let message = PSQLFrontendMessage.saslInitialResponse(sasl) - XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) + encoder.encode(data: message, out: &byteBuffer) let length: Int = 1 + 4 + (sasl.saslMechanism.count + 1) + 4 + sasl.initialData.count @@ -30,12 +30,12 @@ class SASLInitialResponseTests: XCTestCase { } func testEncodeWithoutData() { - let encoder = PSQLFrontendMessageEncoder.forTests + let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() let sasl = PSQLFrontendMessage.SASLInitialResponse( saslMechanism: "hello", initialData: []) let message = PSQLFrontendMessage.saslInitialResponse(sasl) - XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) + encoder.encode(data: message, out: &byteBuffer) let length: Int = 1 + 4 + (sasl.saslMechanism.count + 1) + 4 + sasl.initialData.count diff --git a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift index aeb4448a..2b528ff4 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift @@ -5,11 +5,11 @@ import NIOCore class SASLResponseTests: XCTestCase { func testEncodeWithData() { - let encoder = PSQLFrontendMessageEncoder.forTests + let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() let sasl = PSQLFrontendMessage.SASLResponse(data: [0, 1, 2, 3, 4, 5, 6, 7]) let message = PSQLFrontendMessage.saslResponse(sasl) - XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) + encoder.encode(data: message, out: &byteBuffer) let length: Int = 1 + 4 + (sasl.data.count) @@ -21,11 +21,11 @@ class SASLResponseTests: XCTestCase { } func testEncodeWithoutData() { - let encoder = PSQLFrontendMessageEncoder.forTests + let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() let sasl = PSQLFrontendMessage.SASLResponse(data: []) let message = PSQLFrontendMessage.saslResponse(sasl) - XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) + encoder.encode(data: message, out: &byteBuffer) let length: Int = 1 + 4 diff --git a/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift b/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift index bf7cac41..1cc72bb1 100644 --- a/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift @@ -5,11 +5,11 @@ import NIOCore class SSLRequestTests: XCTestCase { func testSSLRequest() { - let encoder = PSQLFrontendMessageEncoder.forTests + let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() let request = PSQLFrontendMessage.SSLRequest() let message = PSQLFrontendMessage.sslRequest(request) - XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) + encoder.encode(data: message, out: &byteBuffer) let byteBufferLength = Int32(byteBuffer.readableBytes) XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) diff --git a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift index 1224aede..913d02ef 100644 --- a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift @@ -5,7 +5,7 @@ import NIOCore class StartupTests: XCTestCase { func testStartupMessage() { - let encoder = PSQLFrontendMessageEncoder.forTests + let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() let replicationValues: [PSQLFrontendMessage.Startup.Parameters.Replication] = [ @@ -24,7 +24,7 @@ class StartupTests: XCTestCase { let startup = PSQLFrontendMessage.Startup.versionThree(parameters: parameters) let message = PSQLFrontendMessage.startup(startup) - XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer)) + encoder.encode(data: message, out: &byteBuffer) let byteBufferLength = Int32(byteBuffer.readableBytes) XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) diff --git a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift index a9bfb228..f47a0071 100644 --- a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift @@ -181,8 +181,8 @@ class PSQLChannelHandlerTests: XCTestCase { username: username, database: database, password: password, - tlsConfiguration: tlsConfiguration, - coders: .foundation) + tlsConfiguration: tlsConfiguration + ) } } diff --git a/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift index 83b41392..7a8d56eb 100644 --- a/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift @@ -23,9 +23,9 @@ class PSQLFrontendMessageTests: XCTestCase { // MARK: Encoder func testEncodeFlush() { - let encoder = PSQLFrontendMessageEncoder.forTests + let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() - XCTAssertNoThrow(try encoder.encode(data: .flush, out: &byteBuffer)) + encoder.encode(data: .flush, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 5) XCTAssertEqual(PSQLFrontendMessage.ID.flush.rawValue, byteBuffer.readInteger(as: UInt8.self)) @@ -33,9 +33,9 @@ class PSQLFrontendMessageTests: XCTestCase { } func testEncodeSync() { - let encoder = PSQLFrontendMessageEncoder.forTests + let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() - XCTAssertNoThrow(try encoder.encode(data: .sync, out: &byteBuffer)) + encoder.encode(data: .sync, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 5) XCTAssertEqual(PSQLFrontendMessage.ID.sync.rawValue, byteBuffer.readInteger(as: UInt8.self)) @@ -43,9 +43,9 @@ class PSQLFrontendMessageTests: XCTestCase { } func testEncodeTerminate() { - let encoder = PSQLFrontendMessageEncoder.forTests + let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() - XCTAssertNoThrow(try encoder.encode(data: .terminate, out: &byteBuffer)) + encoder.encode(data: .terminate, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 5) XCTAssertEqual(PSQLFrontendMessage.ID.terminate.rawValue, byteBuffer.readInteger(as: UInt8.self)) diff --git a/Tests/PostgresNIOTests/New/PSQLRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PSQLRowSequenceTests.swift index d3dd9665..0dd935a6 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowSequenceTests.swift @@ -16,7 +16,7 @@ final class PSQLRowSequenceTests: XCTestCase { rowDescription: [ .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) ], - queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), eventLoop: eventLoop, rowSource: .stream(dataSource) ) @@ -46,7 +46,7 @@ final class PSQLRowSequenceTests: XCTestCase { rowDescription: [ .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) ], - queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), eventLoop: eventLoop, rowSource: .stream(dataSource) ) @@ -79,7 +79,7 @@ final class PSQLRowSequenceTests: XCTestCase { rowDescription: [ .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) ], - queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), eventLoop: eventLoop, rowSource: .stream(dataSource) ) @@ -106,7 +106,7 @@ final class PSQLRowSequenceTests: XCTestCase { rowDescription: [ .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) ], - queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), eventLoop: eventLoop, rowSource: .stream(dataSource) ) @@ -128,7 +128,7 @@ final class PSQLRowSequenceTests: XCTestCase { rowDescription: [ .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) ], - queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), eventLoop: eventLoop, rowSource: .stream(dataSource) ) @@ -157,7 +157,7 @@ final class PSQLRowSequenceTests: XCTestCase { rowDescription: [ .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) ], - queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), eventLoop: eventLoop, rowSource: .stream(dataSource) ) @@ -187,7 +187,7 @@ final class PSQLRowSequenceTests: XCTestCase { rowDescription: [ .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) ], - queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), eventLoop: eventLoop, rowSource: .stream(dataSource) ) @@ -217,7 +217,7 @@ final class PSQLRowSequenceTests: XCTestCase { rowDescription: [ .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) ], - queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), eventLoop: eventLoop, rowSource: .stream(dataSource) ) @@ -251,7 +251,7 @@ final class PSQLRowSequenceTests: XCTestCase { rowDescription: [ .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) ], - queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), eventLoop: eventLoop, rowSource: .stream(dataSource) ) @@ -289,7 +289,7 @@ final class PSQLRowSequenceTests: XCTestCase { rowDescription: [ .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) ], - queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), eventLoop: eventLoop, rowSource: .stream(dataSource) ) @@ -354,7 +354,7 @@ final class PSQLRowSequenceTests: XCTestCase { rowDescription: [ .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) ], - queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), eventLoop: eventLoop, rowSource: .stream(dataSource) ) @@ -409,7 +409,7 @@ final class PSQLRowSequenceTests: XCTestCase { rowDescription: [ .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) ], - queryContext: .init(query: "SELECT * FROM foo", bind: [], logger: logger, promise: promise), + queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), eventLoop: eventLoop, rowSource: .stream(dataSource) ) diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift index abbfce14..dbf506fa 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -10,7 +10,7 @@ class PSQLRowStreamTests: XCTestCase { let promise = eventLoop.makePromise(of: PSQLRowStream.self) let queryContext = ExtendedQueryContext( - query: "INSERT INTO foo bar;", bind: [], logger: logger, promise: promise + query: "INSERT INTO foo bar;", logger: logger, promise: promise ) let stream = PSQLRowStream( @@ -31,7 +31,7 @@ class PSQLRowStreamTests: XCTestCase { let promise = eventLoop.makePromise(of: PSQLRowStream.self) let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", bind: [], logger: logger, promise: promise + query: "SELECT * FROM test;", logger: logger, promise: promise ) let stream = PSQLRowStream( @@ -53,7 +53,7 @@ class PSQLRowStreamTests: XCTestCase { let promise = eventLoop.makePromise(of: PSQLRowStream.self) let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", bind: [], logger: logger, promise: promise + query: "SELECT * FROM test;", logger: logger, promise: promise ) let dataSource = CountingDataSource() @@ -92,7 +92,7 @@ class PSQLRowStreamTests: XCTestCase { let promise = eventLoop.makePromise(of: PSQLRowStream.self) let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", bind: [], logger: logger, promise: promise) + query: "SELECT * FROM test;", logger: logger, promise: promise) let dataSource = CountingDataSource() let stream = PSQLRowStream( rowDescription: [ @@ -142,7 +142,7 @@ class PSQLRowStreamTests: XCTestCase { let promise = eventLoop.makePromise(of: PSQLRowStream.self) let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", bind: [], logger: logger, promise: promise + query: "SELECT * FROM test;", logger: logger, promise: promise ) let dataSource = CountingDataSource() @@ -186,7 +186,7 @@ class PSQLRowStreamTests: XCTestCase { let promise = eventLoop.makePromise(of: PSQLRowStream.self) let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", bind: [], logger: logger, promise: promise + query: "SELECT * FROM test;", logger: logger, promise: promise ) let dataSource = CountingDataSource() @@ -235,7 +235,7 @@ class PSQLRowStreamTests: XCTestCase { let promise = eventLoop.makePromise(of: PSQLRowStream.self) let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", bind: [], logger: logger, promise: promise + query: "SELECT * FROM test;", logger: logger, promise: promise ) let dataSource = CountingDataSource() diff --git a/Tests/PostgresNIOTests/New/PostgresQueryTests.swift b/Tests/PostgresNIOTests/New/PostgresQueryTests.swift new file mode 100644 index 00000000..24123a54 --- /dev/null +++ b/Tests/PostgresNIOTests/New/PostgresQueryTests.swift @@ -0,0 +1,82 @@ +@testable import PostgresNIO +import XCTest + +final class PostgresQueryTests: XCTestCase { + + func testStringInterpolationWithOptional() throws { + let string = "Hello World" + let null: UUID? = nil + let uuid: UUID? = UUID() + + let query: PostgresQuery = try """ + INSERT INTO foo (id, title, something) SET (\(uuid), \(string), \(null)); + """ + + XCTAssertEqual(query.sql, "INSERT INTO foo (id, title, something) SET ($1, $2, $3);") + + var expected = ByteBuffer() + expected.writeInteger(Int32(16)) + expected.writeBytes([ + uuid!.uuid.0, uuid!.uuid.1, uuid!.uuid.2, uuid!.uuid.3, + uuid!.uuid.4, uuid!.uuid.5, uuid!.uuid.6, uuid!.uuid.7, + uuid!.uuid.8, uuid!.uuid.9, uuid!.uuid.10, uuid!.uuid.11, + uuid!.uuid.12, uuid!.uuid.13, uuid!.uuid.14, uuid!.uuid.15, + ]) + + expected.writeInteger(Int32(string.utf8.count)) + expected.writeString(string) + expected.writeInteger(Int32(-1)) + + XCTAssertEqual(query.binds.bytes, expected) + } + + func testStringInterpolationWithCustomJSONEncoder() throws { + struct Foo: Codable, PSQLCodable { + var helloWorld: String + } + + let jsonEncoder = JSONEncoder() + jsonEncoder.keyEncodingStrategy = .convertToSnakeCase + + let query: PostgresQuery = try """ + INSERT INTO test (foo) SET (\(Foo(helloWorld: "bar"), context: .init(jsonEncoder: jsonEncoder))); + """ + + XCTAssertEqual(query.sql, "INSERT INTO test (foo) SET ($1);") + + let expectedJSON = #"{"hello_world":"bar"}"# + + var expected = ByteBuffer() + expected.writeInteger(Int32(expectedJSON.utf8.count + 1)) + expected.writeInteger(UInt8(0x01)) + expected.writeString(expectedJSON) + + XCTAssertEqual(query.binds.bytes, expected) + } + + func testAllowUsersToGenerateLotsOfRows() throws { + struct Foo: Codable, PSQLCodable { + var helloWorld: String + } + + let jsonEncoder = JSONEncoder() + jsonEncoder.keyEncodingStrategy = .convertToSnakeCase + + let sql = "INSERT INTO test (id) SET (\((1...5).map({"$\($0)"}).joined(separator: ", ")));" + + var query = PostgresQuery(unsafeSQL: sql, binds: .init(capacity: 5)) + for value in 1...5 { + XCTAssertNoThrow(try query.appendBinding(Int(value), context: .default)) + } + + XCTAssertEqual(query.sql, "INSERT INTO test (id) SET ($1, $2, $3, $4, $5);") + + var expected = ByteBuffer() + for value in 1...5 { + expected.writeInteger(UInt32(8)) + expected.writeInteger(value) + } + + XCTAssertEqual(query.binds.bytes, expected) + } +} From 77eb6c75ec5d77e6869f500b00a654d68f9d59a0 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 21 Feb 2022 08:18:25 +0100 Subject: [PATCH 058/292] Rename PSQLDecodable to PostgresDecodable (#224) --- .../New/Data/Array+PSQLCodable.swift | 2 +- .../New/Data/Optional+PSQLCodable.swift | 2 +- Sources/PostgresNIO/New/PSQLCodable.swift | 8 ++--- .../New/PSQLRow-multi-decode.swift | 30 +++++++++---------- Sources/PostgresNIO/New/PSQLRow.swift | 12 ++++---- Sources/PostgresNIO/New/PostgresCell.swift | 2 +- Sources/PostgresNIO/Postgres+PSQLCompat.swift | 2 +- dev/generate-psqlrow-multi-decode.sh | 4 +-- 8 files changed, 31 insertions(+), 31 deletions(-) diff --git a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift index 14b92050..fc8e3b3a 100644 --- a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift @@ -103,7 +103,7 @@ extension Array: PSQLEncodable where Element: PSQLArrayElement { } } -extension Array: PSQLDecodable where Element: PSQLArrayElement { +extension Array: PostgresDecodable where Element: PSQLArrayElement { static func decode( from buffer: inout ByteBuffer, type: PostgresDataType, diff --git a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift index 79ba08af..7ab857e0 100644 --- a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift @@ -1,6 +1,6 @@ import NIOCore -extension Optional: PSQLDecodable where Wrapped: PSQLDecodable, Wrapped.DecodableType == Wrapped { +extension Optional: PostgresDecodable where Wrapped: PostgresDecodable, Wrapped.DecodableType == Wrapped { typealias DecodableType = Wrapped static func decode( diff --git a/Sources/PostgresNIO/New/PSQLCodable.swift b/Sources/PostgresNIO/New/PSQLCodable.swift index 02ed779f..cc302f20 100644 --- a/Sources/PostgresNIO/New/PSQLCodable.swift +++ b/Sources/PostgresNIO/New/PSQLCodable.swift @@ -20,8 +20,8 @@ protocol PSQLEncodable { } /// A type that can decode itself from a postgres wire binary representation. -protocol PSQLDecodable { - associatedtype DecodableType: PSQLDecodable = Self +protocol PostgresDecodable { + associatedtype DecodableType: PostgresDecodable = Self /// Decode an entity from the `byteBuffer` in postgres wire format /// @@ -52,7 +52,7 @@ protocol PSQLDecodable { ) throws -> Self } -extension PSQLDecodable { +extension PostgresDecodable { @inlinable static func decodeRaw( from byteBuffer: inout ByteBuffer?, @@ -68,7 +68,7 @@ extension PSQLDecodable { } /// A type that can be encoded into and decoded from a postgres binary format -protocol PSQLCodable: PSQLEncodable, PSQLDecodable {} +protocol PSQLCodable: PSQLEncodable, PostgresDecodable {} extension PSQLEncodable { func encodeRaw( diff --git a/Sources/PostgresNIO/New/PSQLRow-multi-decode.swift b/Sources/PostgresNIO/New/PSQLRow-multi-decode.swift index 3be5b0c8..26eeb167 100644 --- a/Sources/PostgresNIO/New/PSQLRow-multi-decode.swift +++ b/Sources/PostgresNIO/New/PSQLRow-multi-decode.swift @@ -2,7 +2,7 @@ extension PSQLRow { @inlinable - func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0) { + func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0) { precondition(self.columns.count >= 1) let columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -31,7 +31,7 @@ extension PSQLRow { } @inlinable - func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1) { + func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1) { precondition(self.columns.count >= 2) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -66,7 +66,7 @@ extension PSQLRow { } @inlinable - func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2) { + func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2) { precondition(self.columns.count >= 3) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -107,7 +107,7 @@ extension PSQLRow { } @inlinable - func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3) { + func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3) { precondition(self.columns.count >= 4) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -154,7 +154,7 @@ extension PSQLRow { } @inlinable - func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { + func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { precondition(self.columns.count >= 5) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -207,7 +207,7 @@ extension PSQLRow { } @inlinable - func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { + func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { precondition(self.columns.count >= 6) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -266,7 +266,7 @@ extension PSQLRow { } @inlinable - func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { + func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { precondition(self.columns.count >= 7) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -331,7 +331,7 @@ extension PSQLRow { } @inlinable - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { precondition(self.columns.count >= 8) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -402,7 +402,7 @@ extension PSQLRow { } @inlinable - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { precondition(self.columns.count >= 9) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -479,7 +479,7 @@ extension PSQLRow { } @inlinable - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { precondition(self.columns.count >= 10) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -562,7 +562,7 @@ extension PSQLRow { } @inlinable - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { precondition(self.columns.count >= 11) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -651,7 +651,7 @@ extension PSQLRow { } @inlinable - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { precondition(self.columns.count >= 12) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -746,7 +746,7 @@ extension PSQLRow { } @inlinable - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { precondition(self.columns.count >= 13) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -847,7 +847,7 @@ extension PSQLRow { } @inlinable - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { precondition(self.columns.count >= 14) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -954,7 +954,7 @@ extension PSQLRow { } @inlinable - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { precondition(self.columns.count >= 15) var columnIndex = 0 var cellIterator = self.data.makeIterator() diff --git a/Sources/PostgresNIO/New/PSQLRow.swift b/Sources/PostgresNIO/New/PSQLRow.swift index c86f62a1..91389538 100644 --- a/Sources/PostgresNIO/New/PSQLRow.swift +++ b/Sources/PostgresNIO/New/PSQLRow.swift @@ -27,9 +27,9 @@ extension PSQLRow { /// - Parameters: /// - column: The column name to read the data from /// - type: The type to decode the data into - /// - Throws: The error of the decoding implementation. See also `PSQLDecodable` protocol for this. + /// - Throws: The error of the decoding implementation. See also `PostgresDecodable` protocol for this. /// - Returns: The decoded value of Type T. - func decode(column: String, as type: T.Type, jsonDecoder: JSONDecoder, file: String = #file, line: Int = #line) throws -> T { + func decode(column: String, as type: T.Type, jsonDecoder: JSONDecoder, file: String = #file, line: Int = #line) throws -> T { guard let index = self.lookupTable[column] else { preconditionFailure("A column '\(column)' does not exist.") } @@ -42,9 +42,9 @@ extension PSQLRow { /// - Parameters: /// - column: The column index to read the data from /// - type: The type to decode the data into - /// - Throws: The error of the decoding implementation. See also `PSQLDecodable` protocol for this. + /// - Throws: The error of the decoding implementation. See also `PostgresDecodable` protocol for this. /// - Returns: The decoded value of Type T. - func decode(column index: Int, as type: T.Type, jsonDecoder: JSONDecoder, file: String = #file, line: Int = #line) throws -> T { + func decode(column index: Int, as type: T.Type, jsonDecoder: JSONDecoder, file: String = #file, line: Int = #line) throws -> T { precondition(index < self.data.columnCount) let column = self.columns[index] @@ -59,12 +59,12 @@ extension PSQLRow { extension PSQLRow { // TODO: Remove this function. Only here to keep the tests running as of today. - func decode(column: String, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { + func decode(column: String, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { try self.decode(column: column, as: type, jsonDecoder: JSONDecoder(), file: file, line: line) } // TODO: Remove this function. Only here to keep the tests running as of today. - func decode(column index: Int, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { + func decode(column index: Int, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { try self.decode(column: index, as: type, jsonDecoder: JSONDecoder(), file: file, line: line) } } diff --git a/Sources/PostgresNIO/New/PostgresCell.swift b/Sources/PostgresNIO/New/PostgresCell.swift index a461bb37..8d4bcc7c 100644 --- a/Sources/PostgresNIO/New/PostgresCell.swift +++ b/Sources/PostgresNIO/New/PostgresCell.swift @@ -20,7 +20,7 @@ struct PostgresCell: Equatable { extension PostgresCell { - func decode( + func decode( _: T.Type, context: PostgresDecodingContext, file: String = #file, diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index acc8d735..54694f25 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -31,7 +31,7 @@ extension PostgresData: PSQLEncodable { } } -extension PostgresData: PSQLDecodable { +extension PostgresData: PostgresDecodable { static func decode( from buffer: inout ByteBuffer, type: PostgresDataType, diff --git a/dev/generate-psqlrow-multi-decode.sh b/dev/generate-psqlrow-multi-decode.sh index e58b17b2..f2be1ad1 100755 --- a/dev/generate-psqlrow-multi-decode.sh +++ b/dev/generate-psqlrow-multi-decode.sh @@ -13,9 +13,9 @@ function gen() { echo " @inlinable" #echo " @_alwaysEmitIntoClient" - echo -n " func decode(_: (T0" From 3d5d25fd146a57afab9f3f49c075cd1171f04dee Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 21 Feb 2022 09:57:43 +0100 Subject: [PATCH 059/292] Rename PSQLCodable to PostgresCodable (#225) --- ...PSQLCodable.swift => Array+PostgresCodable.swift} | 6 +++--- ...+PSQLCodable.swift => Bool+PostgresCodable.swift} | 2 +- ...PSQLCodable.swift => Bytes+PostgresCodable.swift} | 6 +++--- ...+PSQLCodable.swift => Date+PostgresCodable.swift} | 2 +- ...QLCodable.swift => Decimal+PostgresCodable.swift} | 2 +- ...PSQLCodable.swift => Float+PostgresCodable.swift} | 4 ++-- ...t+PSQLCodable.swift => Int+PostgresCodable.swift} | 10 +++++----- ...+PSQLCodable.swift => JSON+PostgresCodable.swift} | 2 +- ...LCodable.swift => Optional+PostgresCodable.swift} | 4 ++-- ....swift => RawRepresentable+PostgresCodable.swift} | 2 +- ...SQLCodable.swift => String+PostgresCodable.swift} | 2 +- ...+PSQLCodable.swift => UUID+PostgresCodable.swift} | 2 +- .../New/{PSQLCodable.swift => PostgresCodable.swift} | 6 +++--- Sources/PostgresNIO/New/PostgresQuery.swift | 12 ++++++------ Sources/PostgresNIO/Postgres+PSQLCompat.swift | 4 ++-- Tests/IntegrationTests/PSQLIntegrationTests.swift | 2 +- .../New/Data/Bytes+PSQLCodableTests.swift | 2 +- .../New/Data/JSON+PSQLCodableTests.swift | 2 +- .../New/Data/Optional+PSQLCodableTests.swift | 4 ++-- .../New/Data/RawRepresentable+PSQLCodableTests.swift | 2 +- .../PostgresNIOTests/New/Messages/DataRowTests.swift | 4 ++-- Tests/PostgresNIOTests/New/PostgresQueryTests.swift | 4 ++-- 22 files changed, 43 insertions(+), 43 deletions(-) rename Sources/PostgresNIO/New/Data/{Array+PSQLCodable.swift => Array+PostgresCodable.swift} (96%) rename Sources/PostgresNIO/New/Data/{Bool+PSQLCodable.swift => Bool+PostgresCodable.swift} (97%) rename Sources/PostgresNIO/New/Data/{Bytes+PSQLCodable.swift => Bytes+PostgresCodable.swift} (91%) rename Sources/PostgresNIO/New/Data/{Date+PSQLCodable.swift => Date+PostgresCodable.swift} (98%) rename Sources/PostgresNIO/New/Data/{Decimal+PSQLCodable.swift => Decimal+PostgresCodable.swift} (97%) rename Sources/PostgresNIO/New/Data/{Float+PSQLCodable.swift => Float+PostgresCodable.swift} (97%) rename Sources/PostgresNIO/New/Data/{Int+PSQLCodable.swift => Int+PostgresCodable.swift} (97%) rename Sources/PostgresNIO/New/Data/{JSON+PSQLCodable.swift => JSON+PostgresCodable.swift} (96%) rename Sources/PostgresNIO/New/Data/{Optional+PSQLCodable.swift => Optional+PostgresCodable.swift} (91%) rename Sources/PostgresNIO/New/Data/{RawRepresentable+PSQLCodable.swift => RawRepresentable+PostgresCodable.swift} (91%) rename Sources/PostgresNIO/New/Data/{String+PSQLCodable.swift => String+PostgresCodable.swift} (97%) rename Sources/PostgresNIO/New/Data/{UUID+PSQLCodable.swift => UUID+PostgresCodable.swift} (98%) rename Sources/PostgresNIO/New/{PSQLCodable.swift => PostgresCodable.swift} (97%) diff --git a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift similarity index 96% rename from Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift rename to Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift index fc8e3b3a..c68e6e27 100644 --- a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift @@ -2,7 +2,7 @@ import NIOCore import struct Foundation.UUID /// A type, of which arrays can be encoded into and decoded from a postgres binary format -protocol PSQLArrayElement: PSQLCodable { +protocol PSQLArrayElement: PostgresCodable { static var psqlArrayType: PostgresDataType { get } static var psqlArrayElementType: PostgresDataType { get } } @@ -67,7 +67,7 @@ extension UUID: PSQLArrayElement { static var psqlArrayElementType: PostgresDataType { .uuid } } -extension Array: PSQLEncodable where Element: PSQLArrayElement { +extension Array: PostgresEncodable where Element: PSQLArrayElement { var psqlType: PostgresDataType { Element.psqlArrayType } @@ -155,6 +155,6 @@ extension Array: PostgresDecodable where Element: PSQLArrayElement { } } -extension Array: PSQLCodable where Element: PSQLArrayElement { +extension Array: PostgresCodable where Element: PSQLArrayElement { } diff --git a/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift similarity index 97% rename from Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift rename to Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift index ce0350a2..2e781bff 100644 --- a/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift @@ -1,6 +1,6 @@ import NIOCore -extension Bool: PSQLCodable { +extension Bool: PostgresCodable { var psqlType: PostgresDataType { .bool } diff --git a/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift similarity index 91% rename from Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift rename to Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift index d7e0e804..8126b57a 100644 --- a/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift @@ -2,7 +2,7 @@ import struct Foundation.Data import NIOCore import NIOFoundationCompat -extension PSQLEncodable where Self: Sequence, Self.Element == UInt8 { +extension PostgresEncodable where Self: Sequence, Self.Element == UInt8 { var psqlType: PostgresDataType { .bytea } @@ -19,7 +19,7 @@ extension PSQLEncodable where Self: Sequence, Self.Element == UInt8 { } } -extension ByteBuffer: PSQLCodable { +extension ByteBuffer: PostgresCodable { var psqlType: PostgresDataType { .bytea } @@ -46,7 +46,7 @@ extension ByteBuffer: PSQLCodable { } } -extension Data: PSQLCodable { +extension Data: PostgresCodable { var psqlType: PostgresDataType { .bytea } diff --git a/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift similarity index 98% rename from Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift rename to Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift index d8d48915..680b4343 100644 --- a/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift @@ -1,7 +1,7 @@ import NIOCore import struct Foundation.Date -extension Date: PSQLCodable { +extension Date: PostgresCodable { var psqlType: PostgresDataType { .timestamptz } diff --git a/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift similarity index 97% rename from Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift rename to Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift index aa1569cc..3d9360b6 100644 --- a/Sources/PostgresNIO/New/Data/Decimal+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift @@ -1,7 +1,7 @@ import NIOCore import struct Foundation.Decimal -extension Decimal: PSQLCodable { +extension Decimal: PostgresCodable { var psqlType: PostgresDataType { .numeric } diff --git a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift similarity index 97% rename from Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift rename to Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift index fd5abfb2..b3d6575a 100644 --- a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift @@ -1,6 +1,6 @@ import NIOCore -extension Float: PSQLCodable { +extension Float: PostgresCodable { var psqlType: PostgresDataType { .float4 } @@ -44,7 +44,7 @@ extension Float: PSQLCodable { } } -extension Double: PSQLCodable { +extension Double: PostgresCodable { var psqlType: PostgresDataType { .float8 } diff --git a/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift similarity index 97% rename from Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift rename to Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift index ca373b78..0d6b258d 100644 --- a/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift @@ -1,6 +1,6 @@ import NIOCore -extension UInt8: PSQLCodable { +extension UInt8: PostgresCodable { var psqlType: PostgresDataType { .char } @@ -35,7 +35,7 @@ extension UInt8: PSQLCodable { } } -extension Int16: PSQLCodable { +extension Int16: PostgresCodable { var psqlType: PostgresDataType { .int2 @@ -75,7 +75,7 @@ extension Int16: PSQLCodable { } } -extension Int32: PSQLCodable { +extension Int32: PostgresCodable { var psqlType: PostgresDataType { .int4 } @@ -119,7 +119,7 @@ extension Int32: PSQLCodable { } } -extension Int64: PSQLCodable { +extension Int64: PostgresCodable { var psqlType: PostgresDataType { .int8 } @@ -168,7 +168,7 @@ extension Int64: PSQLCodable { } } -extension Int: PSQLCodable { +extension Int: PostgresCodable { var psqlType: PostgresDataType { switch self.bitWidth { case Int32.bitWidth: diff --git a/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift similarity index 96% rename from Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift rename to Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift index 972f11e7..cd291c71 100644 --- a/Sources/PostgresNIO/New/Data/JSON+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift @@ -5,7 +5,7 @@ import class Foundation.JSONDecoder private let JSONBVersionByte: UInt8 = 0x01 -extension PSQLCodable where Self: Codable { +extension PostgresCodable where Self: Codable { var psqlType: PostgresDataType { .jsonb } diff --git a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Optional+PostgresCodable.swift similarity index 91% rename from Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift rename to Sources/PostgresNIO/New/Data/Optional+PostgresCodable.swift index 7ab857e0..080dc669 100644 --- a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Optional+PostgresCodable.swift @@ -25,7 +25,7 @@ extension Optional: PostgresDecodable where Wrapped: PostgresDecodable, Wrapped. } } -extension Optional: PSQLEncodable where Wrapped: PSQLEncodable { +extension Optional: PostgresEncodable where Wrapped: PostgresEncodable { var psqlType: PostgresDataType { switch self { case .some(let value): @@ -64,6 +64,6 @@ extension Optional: PSQLEncodable where Wrapped: PSQLEncodable { } } -extension Optional: PSQLCodable where Wrapped: PSQLCodable, Wrapped.DecodableType == Wrapped { +extension Optional: PostgresCodable where Wrapped: PostgresCodable, Wrapped.DecodableType == Wrapped { } diff --git a/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift similarity index 91% rename from Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift rename to Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift index 3a05a848..b853eac1 100644 --- a/Sources/PostgresNIO/New/Data/RawRepresentable+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift @@ -1,6 +1,6 @@ import NIOCore -extension PSQLCodable where Self: RawRepresentable, RawValue: PSQLCodable { +extension PostgresCodable where Self: RawRepresentable, RawValue: PostgresCodable { var psqlType: PostgresDataType { self.rawValue.psqlType } diff --git a/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift similarity index 97% rename from Sources/PostgresNIO/New/Data/String+PSQLCodable.swift rename to Sources/PostgresNIO/New/Data/String+PostgresCodable.swift index 481296cc..fba73b1a 100644 --- a/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift @@ -1,7 +1,7 @@ import NIOCore import struct Foundation.UUID -extension String: PSQLCodable { +extension String: PostgresCodable { var psqlType: PostgresDataType { .text } diff --git a/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift similarity index 98% rename from Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift rename to Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift index bf5265f3..43177249 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift @@ -2,7 +2,7 @@ import NIOCore import struct Foundation.UUID import typealias Foundation.uuid_t -extension UUID: PSQLCodable { +extension UUID: PostgresCodable { var psqlType: PostgresDataType { .uuid diff --git a/Sources/PostgresNIO/New/PSQLCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift similarity index 97% rename from Sources/PostgresNIO/New/PSQLCodable.swift rename to Sources/PostgresNIO/New/PostgresCodable.swift index cc302f20..cbc1ca6a 100644 --- a/Sources/PostgresNIO/New/PSQLCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -2,7 +2,7 @@ import NIOCore import Foundation /// A type that can encode itself to a postgres wire binary representation. -protocol PSQLEncodable { +protocol PostgresEncodable { /// identifies the data type that we will encode into `byteBuffer` in `encode` var psqlType: PostgresDataType { get } @@ -68,9 +68,9 @@ extension PostgresDecodable { } /// A type that can be encoded into and decoded from a postgres binary format -protocol PSQLCodable: PSQLEncodable, PostgresDecodable {} +protocol PostgresCodable: PostgresEncodable, PostgresDecodable {} -extension PSQLEncodable { +extension PostgresEncodable { func encodeRaw( into buffer: inout ByteBuffer, context: PSQLEncodingContext diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 7c748e83..362288e5 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -23,7 +23,7 @@ extension PostgresQuery: ExpressibleByStringInterpolation { self.binds = PostgresBindings() } - mutating func appendBinding( + mutating func appendBinding( _ value: Value, context: PSQLEncodingContext ) throws { @@ -47,17 +47,17 @@ extension PostgresQuery { self.sql.append(contentsOf: literal) } - mutating func appendInterpolation(_ value: Value) throws { + mutating func appendInterpolation(_ value: Value) throws { try self.binds.append(value, context: .default) self.sql.append(contentsOf: "$\(self.binds.count)") } - mutating func appendInterpolation(_ value: Optional) throws { + mutating func appendInterpolation(_ value: Optional) throws { try self.binds.append(value, context: .default) self.sql.append(contentsOf: "$\(self.binds.count)") } - mutating func appendInterpolation( + mutating func appendInterpolation( _ value: Value, context: PSQLEncodingContext ) throws { @@ -86,7 +86,7 @@ struct PostgresBindings: Hashable { self.format = format } - init(value: Value) { + init(value: Value) { self.init(dataType: value.psqlType, format: value.psqlFormat) } } @@ -110,7 +110,7 @@ struct PostgresBindings: Hashable { self.bytes.reserveCapacity(128 * capacity) } - mutating func append( + mutating func append( _ value: Value, context: PSQLEncodingContext ) throws { diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index 54694f25..46fa475e 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -1,6 +1,6 @@ import NIOCore -extension PostgresData: PSQLEncodable { +extension PostgresData: PostgresEncodable { var psqlType: PostgresDataType { self.type } @@ -44,7 +44,7 @@ extension PostgresData: PostgresDecodable { } } -extension PostgresData: PSQLCodable {} +extension PostgresData: PostgresCodable {} extension PSQLError { func toPostgresError() -> Error { diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 43e1e25a..61bdb136 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -298,7 +298,7 @@ final class IntegrationTests: XCTestCase { } func testRoundTripJSONB() { - struct Object: Codable, PSQLCodable { + struct Object: Codable, PostgresCodable { let foo: Int let bar: Int } diff --git a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift index a3ad33a7..9747ec19 100644 --- a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift @@ -29,7 +29,7 @@ class Bytes_PSQLCodableTests: XCTestCase { } func testEncodeSequenceWhereElementUInt8() { - struct ByteSequence: Sequence, PSQLEncodable { + struct ByteSequence: Sequence, PostgresEncodable { typealias Element = UInt8 typealias Iterator = Array.Iterator diff --git a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift index d17b139c..d5ade4c7 100644 --- a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift @@ -4,7 +4,7 @@ import NIOCore class JSON_PSQLCodableTests: XCTestCase { - struct Hello: Equatable, Codable, PSQLCodable { + struct Hello: Equatable, Codable, PostgresCodable { let hello: String init(name: String) { diff --git a/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift index 62dbb9d7..1d689b0e 100644 --- a/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift @@ -35,7 +35,7 @@ class Optional_PSQLCodableTests: XCTestCase { func testRoundTripSomeUUIDAsPSQLEncodable() { let value: Optional = UUID() - let encodable: PSQLEncodable = value + let encodable: PostgresEncodable = value var buffer = ByteBuffer() XCTAssertEqual(encodable.psqlType, .uuid) @@ -51,7 +51,7 @@ class Optional_PSQLCodableTests: XCTestCase { func testRoundTripNoneUUIDAsPSQLEncodable() { let value: Optional = .none - let encodable: PSQLEncodable = value + let encodable: PostgresEncodable = value var buffer = ByteBuffer() XCTAssertEqual(encodable.psqlType, .null) diff --git a/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift index 712d8843..1e515f4c 100644 --- a/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift @@ -4,7 +4,7 @@ import NIOCore class RawRepresentable_PSQLCodableTests: XCTestCase { - enum MyRawRepresentable: Int16, PSQLCodable { + enum MyRawRepresentable: Int16, PostgresCodable { case testing = 1 case staging = 2 case production = 3 diff --git a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift index 7db44547..b59c7c87 100644 --- a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift @@ -114,9 +114,9 @@ class DataRowTests: XCTestCase { } extension DataRow: ExpressibleByArrayLiteral { - public typealias ArrayLiteralElement = PSQLEncodable + public typealias ArrayLiteralElement = PostgresEncodable - public init(arrayLiteral elements: PSQLEncodable...) { + public init(arrayLiteral elements: PostgresEncodable...) { var buffer = ByteBuffer() let encodingContext = PSQLEncodingContext(jsonEncoder: JSONEncoder()) diff --git a/Tests/PostgresNIOTests/New/PostgresQueryTests.swift b/Tests/PostgresNIOTests/New/PostgresQueryTests.swift index 24123a54..80f52ea5 100644 --- a/Tests/PostgresNIOTests/New/PostgresQueryTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresQueryTests.swift @@ -31,7 +31,7 @@ final class PostgresQueryTests: XCTestCase { } func testStringInterpolationWithCustomJSONEncoder() throws { - struct Foo: Codable, PSQLCodable { + struct Foo: Codable, PostgresCodable { var helloWorld: String } @@ -55,7 +55,7 @@ final class PostgresQueryTests: XCTestCase { } func testAllowUsersToGenerateLotsOfRows() throws { - struct Foo: Codable, PSQLCodable { + struct Foo: Codable, PostgresCodable { var helloWorld: String } From c98c808a2e493daadb090594bf1bf3fd44411954 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 21 Feb 2022 10:14:24 +0100 Subject: [PATCH 060/292] Rename PSQLRowSequence to PostgresRowSequence (#226) --- Sources/PostgresNIO/New/PSQLRowStream.swift | 4 ++-- ...PSQLRowSequence.swift => PostgresRowSequence.swift} | 10 +++++----- ...uenceTests.swift => PostgresRowSequenceTests.swift} | 6 +++--- 3 files changed, 10 insertions(+), 10 deletions(-) rename Sources/PostgresNIO/New/{PSQLRowSequence.swift => PostgresRowSequence.swift} (98%) rename Tests/PostgresNIOTests/New/{PSQLRowSequenceTests.swift => PostgresRowSequenceTests.swift} (98%) diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index d6aea9a1..e69d219b 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -64,7 +64,7 @@ final class PSQLRowStream { // MARK: Async Sequence #if swift(>=5.5) && canImport(_Concurrency) - func asyncSequence() -> PSQLRowSequence { + func asyncSequence() -> PostgresRowSequence { self.eventLoop.preconditionInEventLoop() guard case .waitingForConsumer(let bufferState) = self.downstreamState else { @@ -90,7 +90,7 @@ final class PSQLRowStream { self.downstreamState = .consumed(.failure(error)) } - return PSQLRowSequence(consumer) + return PostgresRowSequence(consumer) } func demand() { diff --git a/Sources/PostgresNIO/New/PSQLRowSequence.swift b/Sources/PostgresNIO/New/PostgresRowSequence.swift similarity index 98% rename from Sources/PostgresNIO/New/PSQLRowSequence.swift rename to Sources/PostgresNIO/New/PostgresRowSequence.swift index 17ba1659..160cea02 100644 --- a/Sources/PostgresNIO/New/PSQLRowSequence.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence.swift @@ -5,7 +5,7 @@ import NIOConcurrencyHelpers /// An async sequence of ``PSQLRow``s. /// /// - Note: This is a struct to allow us to move to a move only type easily once they become available. -struct PSQLRowSequence: AsyncSequence { +struct PostgresRowSequence: AsyncSequence { typealias Element = PSQLRow typealias AsyncIterator = Iterator @@ -38,7 +38,7 @@ struct PSQLRowSequence: AsyncSequence { } } -extension PSQLRowSequence { +extension PostgresRowSequence { struct Iterator: AsyncIteratorProtocol { typealias Element = PSQLRow @@ -155,11 +155,11 @@ final class AsyncStreamConsumer { } } - func makeAsyncIterator() -> PSQLRowSequence.Iterator { + func makeAsyncIterator() -> PostgresRowSequence.Iterator { self.lock.withLock { self.state.createAsyncIterator() } - let iterator = PSQLRowSequence.Iterator(consumer: self) + let iterator = PostgresRowSequence.Iterator(consumer: self) return iterator } @@ -532,7 +532,7 @@ extension AsyncStreamConsumer { } } -extension PSQLRowSequence { +extension PostgresRowSequence { func collect() async throws -> [PSQLRow] { var result = [PSQLRow]() for try await row in self { diff --git a/Tests/PostgresNIOTests/New/PSQLRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift similarity index 98% rename from Tests/PostgresNIOTests/New/PSQLRowSequenceTests.swift rename to Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index 0dd935a6..d42beb85 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -5,7 +5,7 @@ import XCTest @testable import PostgresNIO #if swift(>=5.5.2) -final class PSQLRowSequenceTests: XCTestCase { +final class PostgresRowSequenceTests: XCTestCase { func testBackpressureWorks() async throws { let eventLoop = EmbeddedEventLoop() @@ -90,7 +90,7 @@ final class PSQLRowSequenceTests: XCTestCase { let dataRows: [DataRow] = (0..<128).map { [ByteBuffer(integer: Int64($0))] } stream.receive(dataRows) - var iterator: PSQLRowSequence.Iterator? = rowSequence.makeAsyncIterator() + var iterator: PostgresRowSequence.Iterator? = rowSequence.makeAsyncIterator() iterator = nil XCTAssertEqual(dataSource.cancelCount, 1) @@ -112,7 +112,7 @@ final class PSQLRowSequenceTests: XCTestCase { ) promise.succeed(stream) - var rowSequence: PSQLRowSequence? = stream.asyncSequence() + var rowSequence: PostgresRowSequence? = stream.asyncSequence() rowSequence = nil XCTAssertEqual(dataSource.cancelCount, 1) From 7f53867076c46d404afa42fca81f48e25baeffb2 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 25 Feb 2022 10:02:18 +0100 Subject: [PATCH 061/292] Rename PSQLEncodingContext to PostgresEncodingContext (#227) --- .../PostgresNIO/New/Data/Array+PostgresCodable.swift | 2 +- .../PostgresNIO/New/Data/Bool+PostgresCodable.swift | 2 +- .../PostgresNIO/New/Data/Bytes+PostgresCodable.swift | 6 +++--- .../PostgresNIO/New/Data/Date+PostgresCodable.swift | 2 +- .../New/Data/Decimal+PostgresCodable.swift | 2 +- .../PostgresNIO/New/Data/Float+PostgresCodable.swift | 4 ++-- .../PostgresNIO/New/Data/Int+PostgresCodable.swift | 10 +++++----- .../PostgresNIO/New/Data/JSON+PostgresCodable.swift | 2 +- .../New/Data/Optional+PostgresCodable.swift | 4 ++-- .../New/Data/RawRepresentable+PostgresCodable.swift | 2 +- .../New/Data/String+PostgresCodable.swift | 2 +- .../PostgresNIO/New/Data/UUID+PostgresCodable.swift | 2 +- Sources/PostgresNIO/New/PostgresCodable.swift | 12 ++++++------ Sources/PostgresNIO/New/PostgresQuery.swift | 6 +++--- Sources/PostgresNIO/Postgres+PSQLCompat.swift | 4 ++-- .../New/Extensions/PSQLCoding+TestUtils.swift | 2 +- .../PostgresNIOTests/New/Messages/DataRowTests.swift | 2 +- 17 files changed, 33 insertions(+), 33 deletions(-) diff --git a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift index c68e6e27..875361e1 100644 --- a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift @@ -78,7 +78,7 @@ extension Array: PostgresEncodable where Element: PSQLArrayElement { func encode( into buffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) throws { // 0 if empty, 1 if not buffer.writeInteger(self.isEmpty ? 0 : 1, as: UInt32.self) diff --git a/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift index 2e781bff..9d9120b8 100644 --- a/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift @@ -51,7 +51,7 @@ extension Bool: PostgresCodable { func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) { byteBuffer.writeInteger(self ? 1 : 0, as: UInt8.self) } diff --git a/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift index 8126b57a..1c98948f 100644 --- a/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift @@ -13,7 +13,7 @@ extension PostgresEncodable where Self: Sequence, Self.Element == UInt8 { func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) { byteBuffer.writeBytes(self) } @@ -30,7 +30,7 @@ extension ByteBuffer: PostgresCodable { func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) { var copyOfSelf = self // dirty hack byteBuffer.writeBuffer(©OfSelf) @@ -57,7 +57,7 @@ extension Data: PostgresCodable { func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) { byteBuffer.writeBytes(self) } diff --git a/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift index 680b4343..cb440367 100644 --- a/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift @@ -36,7 +36,7 @@ extension Date: PostgresCodable { func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) { let seconds = self.timeIntervalSince(Self._psqlDateStart) * Double(Self._microsecondsPerSecond) byteBuffer.writeInteger(Int64(seconds)) diff --git a/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift index 3d9360b6..9159b311 100644 --- a/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift @@ -34,7 +34,7 @@ extension Decimal: PostgresCodable { func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) { let numeric = PostgresNumeric(decimal: self) byteBuffer.writeInteger(numeric.ndigits) diff --git a/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift index b3d6575a..94b70820 100644 --- a/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift @@ -38,7 +38,7 @@ extension Float: PostgresCodable { func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) { byteBuffer.psqlWriteFloat(self) } @@ -82,7 +82,7 @@ extension Double: PostgresCodable { func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) { byteBuffer.psqlWriteDouble(self) } diff --git a/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift index 0d6b258d..6d980a40 100644 --- a/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift @@ -29,7 +29,7 @@ extension UInt8: PostgresCodable { func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) { byteBuffer.writeInteger(self, as: UInt8.self) } @@ -69,7 +69,7 @@ extension Int16: PostgresCodable { func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) { byteBuffer.writeInteger(self, as: Int16.self) } @@ -113,7 +113,7 @@ extension Int32: PostgresCodable { func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) { byteBuffer.writeInteger(self, as: Int32.self) } @@ -162,7 +162,7 @@ extension Int64: PostgresCodable { func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) { byteBuffer.writeInteger(self, as: Int64.self) } @@ -218,7 +218,7 @@ extension Int: PostgresCodable { func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) { byteBuffer.writeInteger(self, as: Int.self) } diff --git a/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift index cd291c71..9e5aeb18 100644 --- a/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift @@ -35,7 +35,7 @@ extension PostgresCodable where Self: Codable { func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) throws { byteBuffer.writeInteger(JSONBVersionByte) try context.jsonEncoder.encode(self, into: &byteBuffer) diff --git a/Sources/PostgresNIO/New/Data/Optional+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Optional+PostgresCodable.swift index 080dc669..0c098fad 100644 --- a/Sources/PostgresNIO/New/Data/Optional+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Optional+PostgresCodable.swift @@ -46,14 +46,14 @@ extension Optional: PostgresEncodable where Wrapped: PostgresEncodable { func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) { preconditionFailure("Should never be hit, since `encodeRaw` is implemented.") } func encodeRaw( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) throws { switch self { case .none: diff --git a/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift index b853eac1..d05b179e 100644 --- a/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift @@ -25,7 +25,7 @@ extension PostgresCodable where Self: RawRepresentable, RawValue: PostgresCodabl func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) throws { try rawValue.encode(into: &byteBuffer, context: context) } diff --git a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift index fba73b1a..538e2db5 100644 --- a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift @@ -12,7 +12,7 @@ extension String: PostgresCodable { func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) { byteBuffer.writeString(self) } diff --git a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift index 43177249..95e21dd3 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift @@ -14,7 +14,7 @@ extension UUID: PostgresCodable { func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) { let uuid = self.uuid byteBuffer.writeBytes([ diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift index cbc1ca6a..d961dd08 100644 --- a/Sources/PostgresNIO/New/PostgresCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -11,12 +11,12 @@ protocol PostgresEncodable { /// Encode the entity into the `byteBuffer` in Postgres binary format, without setting /// the byte count. This method is called from the default `encodeRaw` implementation. - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws + func encode(into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext) throws /// Encode the entity into the `byteBuffer` in Postgres binary format including its /// leading byte count. This method has a default implementation and may be overriden /// only for special cases, like `Optional`s. - func encodeRaw(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws + func encodeRaw(into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext) throws } /// A type that can decode itself from a postgres wire binary representation. @@ -73,7 +73,7 @@ protocol PostgresCodable: PostgresEncodable, PostgresDecodable {} extension PostgresEncodable { func encodeRaw( into buffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) throws { // The length of the parameter value, in bytes (this count does not include // itself). Can be zero. @@ -89,7 +89,7 @@ extension PostgresEncodable { } } -struct PSQLEncodingContext { +struct PostgresEncodingContext { let jsonEncoder: JSONEncoder init(jsonEncoder: JSONEncoder) { @@ -97,8 +97,8 @@ struct PSQLEncodingContext { } } -extension PSQLEncodingContext where JSONEncoder == Foundation.JSONEncoder { - static let `default` = PSQLEncodingContext(jsonEncoder: JSONEncoder()) +extension PostgresEncodingContext where JSONEncoder == Foundation.JSONEncoder { + static let `default` = PostgresEncodingContext(jsonEncoder: JSONEncoder()) } struct PostgresDecodingContext { diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 362288e5..62a74cce 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -25,7 +25,7 @@ extension PostgresQuery: ExpressibleByStringInterpolation { mutating func appendBinding( _ value: Value, - context: PSQLEncodingContext + context: PostgresEncodingContext ) throws { try self.binds.append(value, context: context) } @@ -59,7 +59,7 @@ extension PostgresQuery { mutating func appendInterpolation( _ value: Value, - context: PSQLEncodingContext + context: PostgresEncodingContext ) throws { try self.binds.append(value, context: context) self.sql.append(contentsOf: "$\(self.binds.count)") @@ -112,7 +112,7 @@ struct PostgresBindings: Hashable { mutating func append( _ value: Value, - context: PSQLEncodingContext + context: PostgresEncodingContext ) throws { try value.encodeRaw(into: &self.bytes, context: context) self.metadata.append(.init(value: value)) diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index 46fa475e..f2d97112 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -11,7 +11,7 @@ extension PostgresData: PostgresEncodable { func encode( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) throws { preconditionFailure("Should never be hit, since `encodeRaw` is implemented.") } @@ -19,7 +19,7 @@ extension PostgresData: PostgresEncodable { // encoding func encodeRaw( into byteBuffer: inout ByteBuffer, - context: PSQLEncodingContext + context: PostgresEncodingContext ) { switch self.value { case .none: diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift index 5e561b8e..212a18bd 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift @@ -7,7 +7,7 @@ extension PostgresDecodingContext where JSONDecoder == Foundation.JSONDecoder { } } -extension PSQLEncodingContext where JSONEncoder == Foundation.JSONEncoder { +extension PostgresEncodingContext where JSONEncoder == Foundation.JSONEncoder { static func forTests(jsonEncoder: JSONEncoder = JSONEncoder()) -> Self { Self(jsonEncoder: jsonEncoder) } diff --git a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift index b59c7c87..643c8a28 100644 --- a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift @@ -119,7 +119,7 @@ extension DataRow: ExpressibleByArrayLiteral { public init(arrayLiteral elements: PostgresEncodable...) { var buffer = ByteBuffer() - let encodingContext = PSQLEncodingContext(jsonEncoder: JSONEncoder()) + let encodingContext = PostgresEncodingContext(jsonEncoder: JSONEncoder()) elements.forEach { element in try! element.encodeRaw(into: &buffer, context: encodingContext) } From 041842ba58fff22405f1a4a91204ed1579d5b4d7 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 25 Feb 2022 10:19:34 +0100 Subject: [PATCH 062/292] Move all PostgresConnection code into one file (#228) --- .../PostgresConnection+Authenticate.swift | 24 -- .../PostgresConnection+Connect.swift | 29 -- .../PostgresConnection+Database.swift | 149 ---------- .../PostgresConnection+Notifications.swift | 70 ----- .../Connection/PostgresConnection.swift | 267 +++++++++++++++++- 5 files changed, 266 insertions(+), 273 deletions(-) delete mode 100644 Sources/PostgresNIO/Connection/PostgresConnection+Authenticate.swift delete mode 100644 Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift delete mode 100644 Sources/PostgresNIO/Connection/PostgresConnection+Database.swift delete mode 100644 Sources/PostgresNIO/Connection/PostgresConnection+Notifications.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Authenticate.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Authenticate.swift deleted file mode 100644 index d58943ba..00000000 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Authenticate.swift +++ /dev/null @@ -1,24 +0,0 @@ -import NIOCore -import Logging - -extension PostgresConnection { - public func authenticate( - username: String, - database: String? = nil, - password: String? = nil, - logger: Logger = .init(label: "codes.vapor.postgres") - ) -> EventLoopFuture { - let authContext = AuthContext( - username: username, - password: password, - database: database) - let outgoing = PSQLOutgoingEvent.authenticate(authContext) - self.underlying.channel.triggerUserOutboundEvent(outgoing, promise: nil) - - return self.underlying.channel.pipeline.handler(type: PSQLEventsHandler.self).flatMap { handler in - handler.authenticateFuture - }.flatMapErrorThrowing { error in - throw error.asAppropriatePostgresError - } - } -} diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift deleted file mode 100644 index bcedd2fb..00000000 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift +++ /dev/null @@ -1,29 +0,0 @@ -import NIOCore -import NIOSSL -import Logging - -extension PostgresConnection { - public static func connect( - to socketAddress: SocketAddress, - tlsConfiguration: TLSConfiguration? = nil, - serverHostname: String? = nil, - logger: Logger = .init(label: "codes.vapor.postgres"), - on eventLoop: EventLoop - ) -> EventLoopFuture { - let configuration = PSQLConnection.Configuration( - connection: .resolved(address: socketAddress, serverName: serverHostname), - authentication: nil, - tlsConfiguration: tlsConfiguration - ) - - return PSQLConnection.connect( - configuration: configuration, - logger: logger, - on: eventLoop - ).map { connection in - PostgresConnection(underlying: connection, logger: logger) - }.flatMapErrorThrowing { error in - throw error.asAppropriatePostgresError - } - } -} diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift deleted file mode 100644 index 8b82c1b4..00000000 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift +++ /dev/null @@ -1,149 +0,0 @@ -import NIOCore -import Logging -import struct Foundation.Data - -extension PostgresConnection: PostgresDatabase { - public func send( - _ request: PostgresRequest, - logger: Logger - ) -> EventLoopFuture { - guard let command = request as? PostgresCommands else { - preconditionFailure("\(#function) requires an instance of PostgresCommands. This will be a compile-time error in the future.") - } - - let resultFuture: EventLoopFuture - - switch command { - case .query(let query, let binds, let onMetadata, let onRow): - var psqlQuery = PostgresQuery(unsafeSQL: query, binds: .init(capacity: binds.count)) - binds.forEach { - // We can bang the try here as encoding PostgresData does not throw. The throw - // is just an option for the protocol. - try! psqlQuery.appendBinding($0, context: .default) - } - - resultFuture = self.underlying.query(psqlQuery, logger: logger).flatMap { stream in - let fields = stream.rowDescription.map { column in - PostgresMessage.RowDescription.Field( - name: column.name, - tableOID: UInt32(column.tableOID), - columnAttributeNumber: column.columnAttributeNumber, - dataType: PostgresDataType(UInt32(column.dataType.rawValue)), - dataTypeSize: column.dataTypeSize, - dataTypeModifier: column.dataTypeModifier, - formatCode: .init(psqlFormatCode: column.format) - ) - } - - let lookupTable = PostgresRow.LookupTable(rowDescription: .init(fields: fields), resultFormat: [.binary]) - return stream.iterateRowsWithoutBackpressureOption(lookupTable: lookupTable, onRow: onRow).map { _ in - onMetadata(PostgresQueryMetadata(string: stream.commandTag)!) - } - } - case .queryAll(let query, let binds, let onResult): - var psqlQuery = PostgresQuery(unsafeSQL: query, binds: .init(capacity: binds.count)) - binds.forEach { - // We can bang the try here as encoding PostgresData does not throw. The throw - // is just an option for the protocol. - try! psqlQuery.appendBinding($0, context: .default) - } - - resultFuture = self.underlying.query(psqlQuery, logger: logger).flatMap { rows in - let fields = rows.rowDescription.map { column in - PostgresMessage.RowDescription.Field( - name: column.name, - tableOID: UInt32(column.tableOID), - columnAttributeNumber: column.columnAttributeNumber, - dataType: PostgresDataType(UInt32(column.dataType.rawValue)), - dataTypeSize: column.dataTypeSize, - dataTypeModifier: column.dataTypeModifier, - formatCode: .init(psqlFormatCode: column.format) - ) - } - - let lookupTable = PostgresRow.LookupTable(rowDescription: .init(fields: fields), resultFormat: [.binary]) - return rows.all().map { allrows in - let r = allrows.map { psqlRow -> PostgresRow in - let columns = psqlRow.data.map { - PostgresMessage.DataRow.Column(value: $0) - } - return PostgresRow(dataRow: .init(columns: columns), lookupTable: lookupTable) - } - - onResult(.init(metadata: PostgresQueryMetadata(string: rows.commandTag)!, rows: r)) - } - } - - case .prepareQuery(let request): - resultFuture = self.underlying.prepareStatement(request.query, with: request.name, logger: self.logger).map { - request.prepared = PreparedQuery(underlying: $0, database: self) - } - case .executePreparedStatement(let preparedQuery, let binds, let onRow): - var bindings = PostgresBindings() - binds.forEach { data in - try! bindings.append(data, context: .default) - } - - let statement = PSQLExecuteStatement( - name: preparedQuery.underlying.name, - binds: bindings, - rowDescription: preparedQuery.underlying.rowDescription - ) - - resultFuture = self.underlying.execute(statement, logger: logger).flatMap { rows in - guard let lookupTable = preparedQuery.lookupTable else { - return self.eventLoop.makeSucceededFuture(()) - } - - return rows.iterateRowsWithoutBackpressureOption(lookupTable: lookupTable, onRow: onRow) - } - } - - return resultFuture.flatMapErrorThrowing { error in - throw error.asAppropriatePostgresError - } - } - - public func withConnection(_ closure: (PostgresConnection) -> EventLoopFuture) -> EventLoopFuture { - closure(self) - } -} - -internal enum PostgresCommands: PostgresRequest { - case query(query: String, - binds: [PostgresData], - onMetadata: (PostgresQueryMetadata) -> () = { _ in }, - onRow: (PostgresRow) throws -> ()) - case queryAll(query: String, - binds: [PostgresData], - onResult: (PostgresQueryResult) -> ()) - case prepareQuery(request: PrepareQueryRequest) - case executePreparedStatement(query: PreparedQuery, binds: [PostgresData], onRow: (PostgresRow) throws -> ()) - - func respond(to message: PostgresMessage) throws -> [PostgresMessage]? { - fatalError("This function must not be called") - } - - func start() throws -> [PostgresMessage] { - fatalError("This function must not be called") - } - - func log(to logger: Logger) { - fatalError("This function must not be called") - } -} - -extension PSQLRowStream { - - func iterateRowsWithoutBackpressureOption(lookupTable: PostgresRow.LookupTable, onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { - self.onRow { psqlRow in - let columns = psqlRow.data.map { - PostgresMessage.DataRow.Column(value: $0) - } - - let row = PostgresRow(dataRow: .init(columns: columns), lookupTable: lookupTable) - try onRow(row) - } - } - -} diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Notifications.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Notifications.swift deleted file mode 100644 index 9a21437d..00000000 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Notifications.swift +++ /dev/null @@ -1,70 +0,0 @@ -import NIOCore -import Logging - -/// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support. -public final class PostgresListenContext { - var stopper: (() -> Void)? - - /// Detach this listener so it no longer receives notifications. Other listeners, including those for the same channel, are unaffected. `UNLISTEN` is not sent; you are responsible for issuing an `UNLISTEN` query yourself if it is appropriate for your application. - public func stop() { - stopper?() - stopper = nil - } -} - -extension PostgresConnection { - /// Add a handler for NotificationResponse messages on a certain channel. This is used in conjunction with PostgreSQL's `LISTEN`/`NOTIFY` support: to listen on a channel, you add a listener using this method to handle the NotificationResponse messages, then issue a `LISTEN` query to instruct PostgreSQL to begin sending NotificationResponse messages. - @discardableResult - public func addListener(channel: String, handler notificationHandler: @escaping (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void) -> PostgresListenContext { - - let listenContext = PostgresListenContext() - - self.underlying.channel.pipeline.handler(type: PSQLChannelHandler.self).whenSuccess { handler in - if self.notificationListeners[channel] != nil { - self.notificationListeners[channel]!.append((listenContext, notificationHandler)) - } - else { - self.notificationListeners[channel] = [(listenContext, notificationHandler)] - } - } - - listenContext.stopper = { [weak self, weak listenContext] in - // self is weak, since the connection can long be gone, when the listeners stop is - // triggered. listenContext must be weak to prevent a retain cycle - - self?.underlying.channel.eventLoop.execute { - guard - let self = self, // the connection is already gone - var listeners = self.notificationListeners[channel] // we don't have the listeners for this topic ¯\_(ツ)_/¯ - else { - return - } - - assert(listeners.filter { $0.0 === listenContext }.count <= 1, "Listeners can not appear twice in a channel!") - listeners.removeAll(where: { $0.0 === listenContext }) // just in case a listener shows up more than once in a release build, remove all, not just first - self.notificationListeners[channel] = listeners.isEmpty ? nil : listeners - } - } - - return listenContext - } -} - -extension PostgresConnection: PSQLChannelHandlerNotificationDelegate { - func notificationReceived(_ notification: PSQLBackendMessage.NotificationResponse) { - self.underlying.eventLoop.assertInEventLoop() - - guard let listeners = self.notificationListeners[notification.channel] else { - return - } - - let postgresNotification = PostgresMessage.NotificationResponse( - backendPID: notification.backendPID, - channel: notification.channel, - payload: notification.payload) - - listeners.forEach { (listenContext, handler) in - handler(listenContext, postgresNotification) - } - } -} diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index c400711e..5f00ab79 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -1,6 +1,6 @@ import NIOCore +import NIOSSL import Logging -import struct Foundation.UUID public final class PostgresConnection { let underlying: PSQLConnection @@ -43,3 +43,268 @@ public final class PostgresConnection { return self.underlying.close() } } + +// MARK: Connect + +extension PostgresConnection { + public static func connect( + to socketAddress: SocketAddress, + tlsConfiguration: TLSConfiguration? = nil, + serverHostname: String? = nil, + logger: Logger = .init(label: "codes.vapor.postgres"), + on eventLoop: EventLoop + ) -> EventLoopFuture { + let configuration = PSQLConnection.Configuration( + connection: .resolved(address: socketAddress, serverName: serverHostname), + authentication: nil, + tlsConfiguration: tlsConfiguration + ) + + return PSQLConnection.connect( + configuration: configuration, + logger: logger, + on: eventLoop + ).map { connection in + PostgresConnection(underlying: connection, logger: logger) + }.flatMapErrorThrowing { error in + throw error.asAppropriatePostgresError + } + } + + public func authenticate( + username: String, + database: String? = nil, + password: String? = nil, + logger: Logger = .init(label: "codes.vapor.postgres") + ) -> EventLoopFuture { + let authContext = AuthContext( + username: username, + password: password, + database: database) + let outgoing = PSQLOutgoingEvent.authenticate(authContext) + self.underlying.channel.triggerUserOutboundEvent(outgoing, promise: nil) + + return self.underlying.channel.pipeline.handler(type: PSQLEventsHandler.self).flatMap { handler in + handler.authenticateFuture + }.flatMapErrorThrowing { error in + throw error.asAppropriatePostgresError + } + } +} + +// MARK: PostgresDatabase + +extension PostgresConnection: PostgresDatabase { + public func send( + _ request: PostgresRequest, + logger: Logger + ) -> EventLoopFuture { + guard let command = request as? PostgresCommands else { + preconditionFailure("\(#function) requires an instance of PostgresCommands. This will be a compile-time error in the future.") + } + + let resultFuture: EventLoopFuture + + switch command { + case .query(let query, let binds, let onMetadata, let onRow): + var psqlQuery = PostgresQuery(unsafeSQL: query, binds: .init(capacity: binds.count)) + binds.forEach { + // We can bang the try here as encoding PostgresData does not throw. The throw + // is just an option for the protocol. + try! psqlQuery.appendBinding($0, context: .default) + } + + resultFuture = self.underlying.query(psqlQuery, logger: logger).flatMap { stream in + let fields = stream.rowDescription.map { column in + PostgresMessage.RowDescription.Field( + name: column.name, + tableOID: UInt32(column.tableOID), + columnAttributeNumber: column.columnAttributeNumber, + dataType: PostgresDataType(UInt32(column.dataType.rawValue)), + dataTypeSize: column.dataTypeSize, + dataTypeModifier: column.dataTypeModifier, + formatCode: .init(psqlFormatCode: column.format) + ) + } + + let lookupTable = PostgresRow.LookupTable(rowDescription: .init(fields: fields), resultFormat: [.binary]) + return stream.iterateRowsWithoutBackpressureOption(lookupTable: lookupTable, onRow: onRow).map { _ in + onMetadata(PostgresQueryMetadata(string: stream.commandTag)!) + } + } + case .queryAll(let query, let binds, let onResult): + var psqlQuery = PostgresQuery(unsafeSQL: query, binds: .init(capacity: binds.count)) + binds.forEach { + // We can bang the try here as encoding PostgresData does not throw. The throw + // is just an option for the protocol. + try! psqlQuery.appendBinding($0, context: .default) + } + + resultFuture = self.underlying.query(psqlQuery, logger: logger).flatMap { rows in + let fields = rows.rowDescription.map { column in + PostgresMessage.RowDescription.Field( + name: column.name, + tableOID: UInt32(column.tableOID), + columnAttributeNumber: column.columnAttributeNumber, + dataType: PostgresDataType(UInt32(column.dataType.rawValue)), + dataTypeSize: column.dataTypeSize, + dataTypeModifier: column.dataTypeModifier, + formatCode: .init(psqlFormatCode: column.format) + ) + } + + let lookupTable = PostgresRow.LookupTable(rowDescription: .init(fields: fields), resultFormat: [.binary]) + return rows.all().map { allrows in + let r = allrows.map { psqlRow -> PostgresRow in + let columns = psqlRow.data.map { + PostgresMessage.DataRow.Column(value: $0) + } + return PostgresRow(dataRow: .init(columns: columns), lookupTable: lookupTable) + } + + onResult(.init(metadata: PostgresQueryMetadata(string: rows.commandTag)!, rows: r)) + } + } + + case .prepareQuery(let request): + resultFuture = self.underlying.prepareStatement(request.query, with: request.name, logger: self.logger).map { + request.prepared = PreparedQuery(underlying: $0, database: self) + } + case .executePreparedStatement(let preparedQuery, let binds, let onRow): + var bindings = PostgresBindings() + binds.forEach { data in + try! bindings.append(data, context: .default) + } + + let statement = PSQLExecuteStatement( + name: preparedQuery.underlying.name, + binds: bindings, + rowDescription: preparedQuery.underlying.rowDescription + ) + + resultFuture = self.underlying.execute(statement, logger: logger).flatMap { rows in + guard let lookupTable = preparedQuery.lookupTable else { + return self.eventLoop.makeSucceededFuture(()) + } + + return rows.iterateRowsWithoutBackpressureOption(lookupTable: lookupTable, onRow: onRow) + } + } + + return resultFuture.flatMapErrorThrowing { error in + throw error.asAppropriatePostgresError + } + } + + public func withConnection(_ closure: (PostgresConnection) -> EventLoopFuture) -> EventLoopFuture { + closure(self) + } +} + +internal enum PostgresCommands: PostgresRequest { + case query(query: String, + binds: [PostgresData], + onMetadata: (PostgresQueryMetadata) -> () = { _ in }, + onRow: (PostgresRow) throws -> ()) + case queryAll(query: String, + binds: [PostgresData], + onResult: (PostgresQueryResult) -> ()) + case prepareQuery(request: PrepareQueryRequest) + case executePreparedStatement(query: PreparedQuery, binds: [PostgresData], onRow: (PostgresRow) throws -> ()) + + func respond(to message: PostgresMessage) throws -> [PostgresMessage]? { + fatalError("This function must not be called") + } + + func start() throws -> [PostgresMessage] { + fatalError("This function must not be called") + } + + func log(to logger: Logger) { + fatalError("This function must not be called") + } +} + +extension PSQLRowStream { + + func iterateRowsWithoutBackpressureOption(lookupTable: PostgresRow.LookupTable, onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { + self.onRow { psqlRow in + let columns = psqlRow.data.map { + PostgresMessage.DataRow.Column(value: $0) + } + + let row = PostgresRow(dataRow: .init(columns: columns), lookupTable: lookupTable) + try onRow(row) + } + } +} + +// MARK: Notifications + +/// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support. +public final class PostgresListenContext { + var stopper: (() -> Void)? + + /// Detach this listener so it no longer receives notifications. Other listeners, including those for the same channel, are unaffected. `UNLISTEN` is not sent; you are responsible for issuing an `UNLISTEN` query yourself if it is appropriate for your application. + public func stop() { + stopper?() + stopper = nil + } +} + +extension PostgresConnection { + /// Add a handler for NotificationResponse messages on a certain channel. This is used in conjunction with PostgreSQL's `LISTEN`/`NOTIFY` support: to listen on a channel, you add a listener using this method to handle the NotificationResponse messages, then issue a `LISTEN` query to instruct PostgreSQL to begin sending NotificationResponse messages. + @discardableResult + public func addListener(channel: String, handler notificationHandler: @escaping (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void) -> PostgresListenContext { + + let listenContext = PostgresListenContext() + + self.underlying.channel.pipeline.handler(type: PSQLChannelHandler.self).whenSuccess { handler in + if self.notificationListeners[channel] != nil { + self.notificationListeners[channel]!.append((listenContext, notificationHandler)) + } + else { + self.notificationListeners[channel] = [(listenContext, notificationHandler)] + } + } + + listenContext.stopper = { [weak self, weak listenContext] in + // self is weak, since the connection can long be gone, when the listeners stop is + // triggered. listenContext must be weak to prevent a retain cycle + + self?.underlying.channel.eventLoop.execute { + guard + let self = self, // the connection is already gone + var listeners = self.notificationListeners[channel] // we don't have the listeners for this topic ¯\_(ツ)_/¯ + else { + return + } + + assert(listeners.filter { $0.0 === listenContext }.count <= 1, "Listeners can not appear twice in a channel!") + listeners.removeAll(where: { $0.0 === listenContext }) // just in case a listener shows up more than once in a release build, remove all, not just first + self.notificationListeners[channel] = listeners.isEmpty ? nil : listeners + } + } + + return listenContext + } +} + +extension PostgresConnection: PSQLChannelHandlerNotificationDelegate { + func notificationReceived(_ notification: PSQLBackendMessage.NotificationResponse) { + self.underlying.eventLoop.assertInEventLoop() + + guard let listeners = self.notificationListeners[notification.channel] else { + return + } + + let postgresNotification = PostgresMessage.NotificationResponse( + backendPID: notification.backendPID, + channel: notification.channel, + payload: notification.payload) + + listeners.forEach { (listenContext, handler) in + handler(listenContext, postgresNotification) + } + } +} From 8657fbbffb72a1e7d1ead306aab7b45f1094bffa Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 25 Feb 2022 15:05:53 +0100 Subject: [PATCH 063/292] Deprecate unused PostgresMessages (#229) --- .../{Message => Deprecated}/PostgresMessage+Authentication.swift | 1 + .../{Message => Deprecated}/PostgresMessage+Bind.swift | 1 + .../{Message => Deprecated}/PostgresMessage+Close.swift | 1 + .../PostgresMessage+CommandComplete.swift | 1 + .../{Message => Deprecated}/PostgresMessage+Describe.swift | 1 + .../{Message => Deprecated}/PostgresMessage+Execute.swift | 1 + .../PostgresMessage+ParameterDescription.swift | 1 + .../PostgresMessage+ParameterStatus.swift | 1 + .../{Message => Deprecated}/PostgresMessage+Parse.swift | 1 + .../{Message => Deprecated}/PostgresMessage+Password.swift | 1 + .../{Message => Deprecated}/PostgresMessage+ReadyForQuery.swift | 1 + .../{Message => Deprecated}/PostgresMessage+SASLResponse.swift | 1 + .../{Message => Deprecated}/PostgresMessage+SSLRequest.swift | 1 + .../{Message => Deprecated}/PostgresMessage+SimpleQuery.swift | 1 + .../{Message => Deprecated}/PostgresMessage+Startup.swift | 1 + .../{Message => Deprecated}/PostgresMessage+Sync.swift | 1 + .../{Message => Deprecated}/PostgresMessage+Terminate.swift | 1 + .../{Message => Deprecated}/PostgresMessageDecoder.swift | 1 + .../{Message => Deprecated}/PostgresMessageEncoder.swift | 1 + Tests/PostgresNIOTests/Message/PostgresMessageDecoderTests.swift | 1 + 20 files changed, 20 insertions(+) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessage+Authentication.swift (98%) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessage+Bind.swift (97%) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessage+Close.swift (94%) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessage+CommandComplete.swift (90%) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessage+Describe.swift (94%) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessage+Execute.swift (92%) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessage+ParameterDescription.swift (93%) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessage+ParameterStatus.swift (93%) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessage+Parse.swift (94%) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessage+Password.swift (92%) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessage+ReadyForQuery.swift (94%) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessage+SASLResponse.swift (97%) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessage+SSLRequest.swift (92%) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessage+SimpleQuery.swift (87%) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessage+Startup.swift (96%) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessage+Sync.swift (85%) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessage+Terminate.swift (78%) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessageDecoder.swift (97%) rename Sources/PostgresNIO/{Message => Deprecated}/PostgresMessageEncoder.swift (94%) diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Authentication.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Authentication.swift similarity index 98% rename from Sources/PostgresNIO/Message/PostgresMessage+Authentication.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Authentication.swift index 44523a5c..da7c25d5 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Authentication.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Authentication.swift @@ -2,6 +2,7 @@ import NIOCore extension PostgresMessage { /// Authentication request returned by the server. + @available(*, deprecated, message: "Will be removed from public API") public enum Authentication: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .authentication diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Bind.swift similarity index 97% rename from Sources/PostgresNIO/Message/PostgresMessage+Bind.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Bind.swift index ca8d4aa8..5ff4bbf0 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Bind.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Bind.swift @@ -2,6 +2,7 @@ import NIOCore extension PostgresMessage { /// Identifies the message as a Bind command. + @available(*, deprecated, message: "Will be removed from public API") public struct Bind: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .bind diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Close.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Close.swift similarity index 94% rename from Sources/PostgresNIO/Message/PostgresMessage+Close.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Close.swift index 9e5dd99e..9bcc8aa1 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Close.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Close.swift @@ -2,6 +2,7 @@ import NIOCore extension PostgresMessage { /// Identifies the message as a Close Command + @available(*, deprecated, message: "Will be removed from public API") public struct Close: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .close diff --git a/Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+CommandComplete.swift similarity index 90% rename from Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+CommandComplete.swift index 406dc036..c9370402 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+CommandComplete.swift @@ -2,6 +2,7 @@ import NIOCore extension PostgresMessage { /// Identifies the message as a Close command. + @available(*, deprecated, message: "Will be removed from public API") public struct CommandComplete: PostgresMessageType { /// Parses an instance of this message type from a byte buffer. public static func parse(from buffer: inout ByteBuffer) throws -> CommandComplete { diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Describe.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Describe.swift similarity index 94% rename from Sources/PostgresNIO/Message/PostgresMessage+Describe.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Describe.swift index 8c3bc8f5..787355db 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Describe.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Describe.swift @@ -2,6 +2,7 @@ import NIOCore extension PostgresMessage { /// Identifies the message as a Describe command. + @available(*, deprecated, message: "Will be removed from public API") public struct Describe: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .describe diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Execute.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Execute.swift similarity index 92% rename from Sources/PostgresNIO/Message/PostgresMessage+Execute.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Execute.swift index 4b8bc999..39b447a4 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Execute.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Execute.swift @@ -2,6 +2,7 @@ import NIOCore extension PostgresMessage { /// Identifies the message as an Execute command. + @available(*, deprecated, message: "Will be removed from public API") public struct Execute: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .execute diff --git a/Sources/PostgresNIO/Message/PostgresMessage+ParameterDescription.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+ParameterDescription.swift similarity index 93% rename from Sources/PostgresNIO/Message/PostgresMessage+ParameterDescription.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+ParameterDescription.swift index 3dfdb8e1..89e67682 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+ParameterDescription.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+ParameterDescription.swift @@ -2,6 +2,7 @@ import NIOCore extension PostgresMessage { /// Identifies the message as a parameter description. + @available(*, deprecated, message: "Will be removed from public API") public struct ParameterDescription: PostgresMessageType { /// Parses an instance of this message type from a byte buffer. public static func parse(from buffer: inout ByteBuffer) throws -> ParameterDescription { diff --git a/Sources/PostgresNIO/Message/PostgresMessage+ParameterStatus.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+ParameterStatus.swift similarity index 93% rename from Sources/PostgresNIO/Message/PostgresMessage+ParameterStatus.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+ParameterStatus.swift index 5e2f5881..5ad6f95e 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+ParameterStatus.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+ParameterStatus.swift @@ -1,6 +1,7 @@ import NIOCore extension PostgresMessage { + @available(*, deprecated, message: "Will be removed from public API") public struct ParameterStatus: PostgresMessageType, CustomStringConvertible { /// Parses an instance of this message type from a byte buffer. public static func parse(from buffer: inout ByteBuffer) throws -> ParameterStatus { diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Parse.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Parse.swift similarity index 94% rename from Sources/PostgresNIO/Message/PostgresMessage+Parse.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Parse.swift index 030076d0..8fb5a1ff 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Parse.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Parse.swift @@ -2,6 +2,7 @@ import NIOCore extension PostgresMessage { /// Identifies the message as a Parse command. + @available(*, deprecated, message: "Will be removed from public API") public struct Parse: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .parse diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Password.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Password.swift similarity index 92% rename from Sources/PostgresNIO/Message/PostgresMessage+Password.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Password.swift index 5b2cef63..cafe9cda 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Password.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Password.swift @@ -4,6 +4,7 @@ extension PostgresMessage { /// Identifies the message as a password response. Note that this is also used for /// GSSAPI and SSPI response messages (which is really a design error, since the contained /// data is not a null-terminated string in that case, but can be arbitrary binary data). + @available(*, deprecated, message: "Will be removed from public API") public struct Password: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .passwordMessage diff --git a/Sources/PostgresNIO/Message/PostgresMessage+ReadyForQuery.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+ReadyForQuery.swift similarity index 94% rename from Sources/PostgresNIO/Message/PostgresMessage+ReadyForQuery.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+ReadyForQuery.swift index c46047dd..5afc0910 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+ReadyForQuery.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+ReadyForQuery.swift @@ -2,6 +2,7 @@ import NIOCore extension PostgresMessage { /// Identifies the message type. ReadyForQuery is sent whenever the backend is ready for a new query cycle. + @available(*, deprecated, message: "Will be removed from public API") public struct ReadyForQuery: CustomStringConvertible { /// Parses an instance of this message type from a byte buffer. public static func parse(from buffer: inout ByteBuffer) throws -> ReadyForQuery { diff --git a/Sources/PostgresNIO/Message/PostgresMessage+SASLResponse.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+SASLResponse.swift similarity index 97% rename from Sources/PostgresNIO/Message/PostgresMessage+SASLResponse.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+SASLResponse.swift index 553edc2c..dba414ce 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+SASLResponse.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+SASLResponse.swift @@ -2,6 +2,7 @@ import NIOCore extension PostgresMessage { /// SASL ongoing challenge response message sent by the client. + @available(*, deprecated, message: "Will be removed from public API") public struct SASLResponse: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .saslResponse diff --git a/Sources/PostgresNIO/Message/PostgresMessage+SSLRequest.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+SSLRequest.swift similarity index 92% rename from Sources/PostgresNIO/Message/PostgresMessage+SSLRequest.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+SSLRequest.swift index a636f23f..ee504932 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+SSLRequest.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+SSLRequest.swift @@ -3,6 +3,7 @@ import NIOCore extension PostgresMessage { /// A message asking the PostgreSQL server if SSL is supported /// For more info, see https://www.postgresql.org/docs/10/static/protocol-flow.html#id-1.10.5.7.11 + @available(*, deprecated, message: "Will be removed from public API") public struct SSLRequest: PostgresMessageType { /// The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits, /// and 5679 in the least significant 16 bits. diff --git a/Sources/PostgresNIO/Message/PostgresMessage+SimpleQuery.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+SimpleQuery.swift similarity index 87% rename from Sources/PostgresNIO/Message/PostgresMessage+SimpleQuery.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+SimpleQuery.swift index 7b1ec2f9..a0a6cfcf 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+SimpleQuery.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+SimpleQuery.swift @@ -2,6 +2,7 @@ import NIOCore extension PostgresMessage { /// Identifies the message as a simple query. + @available(*, deprecated, message: "Will be removed from public API") public struct SimpleQuery: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .query diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Startup.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Startup.swift similarity index 96% rename from Sources/PostgresNIO/Message/PostgresMessage+Startup.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Startup.swift index d4d09009..e9762439 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Startup.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Startup.swift @@ -2,6 +2,7 @@ import NIOCore extension PostgresMessage { /// First message sent from the frontend during startup. + @available(*, deprecated, message: "Will be removed from public API") public struct Startup: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .none diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Sync.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Sync.swift similarity index 85% rename from Sources/PostgresNIO/Message/PostgresMessage+Sync.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Sync.swift index 37d54dd7..0560ef7a 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Sync.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Sync.swift @@ -2,6 +2,7 @@ import NIOCore extension PostgresMessage { /// Identifies the message as a Bind command. + @available(*, deprecated, message: "Will be removed from public API") public struct Sync: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { return .sync diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Terminate.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+Terminate.swift similarity index 78% rename from Sources/PostgresNIO/Message/PostgresMessage+Terminate.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessage+Terminate.swift index 5e34665a..afeae5bf 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Terminate.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+Terminate.swift @@ -1,6 +1,7 @@ import NIOCore extension PostgresMessage { + @available(*, deprecated, message: "Will be removed from public API") public struct Terminate: PostgresMessageType { public static var identifier: PostgresMessage.Identifier { .terminate diff --git a/Sources/PostgresNIO/Message/PostgresMessageDecoder.swift b/Sources/PostgresNIO/Deprecated/PostgresMessageDecoder.swift similarity index 97% rename from Sources/PostgresNIO/Message/PostgresMessageDecoder.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessageDecoder.swift index 53ce73de..e092c234 100644 --- a/Sources/PostgresNIO/Message/PostgresMessageDecoder.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessageDecoder.swift @@ -1,6 +1,7 @@ import NIOCore import Logging +@available(*, deprecated, message: "Will be removed from public API") public final class PostgresMessageDecoder: ByteToMessageDecoder { /// See `ByteToMessageDecoder`. public typealias InboundOut = PostgresMessage diff --git a/Sources/PostgresNIO/Message/PostgresMessageEncoder.swift b/Sources/PostgresNIO/Deprecated/PostgresMessageEncoder.swift similarity index 94% rename from Sources/PostgresNIO/Message/PostgresMessageEncoder.swift rename to Sources/PostgresNIO/Deprecated/PostgresMessageEncoder.swift index 19f467a4..8dd4c38d 100644 --- a/Sources/PostgresNIO/Message/PostgresMessageEncoder.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessageEncoder.swift @@ -1,6 +1,7 @@ import NIOCore import Logging +@available(*, deprecated, message: "Will be removed from public API") public final class PostgresMessageEncoder: MessageToByteEncoder { /// See `MessageToByteEncoder`. public typealias OutboundIn = PostgresMessage diff --git a/Tests/PostgresNIOTests/Message/PostgresMessageDecoderTests.swift b/Tests/PostgresNIOTests/Message/PostgresMessageDecoderTests.swift index e9a970ef..d4557a55 100644 --- a/Tests/PostgresNIOTests/Message/PostgresMessageDecoderTests.swift +++ b/Tests/PostgresNIOTests/Message/PostgresMessageDecoderTests.swift @@ -3,6 +3,7 @@ import XCTest import NIOTestUtils class PostgresMessageDecoderTests: XCTestCase { + @available(*, deprecated, message: "Tests deprecated API") func testMessageDecoder() { let sample: [UInt8] = [ 0x52, // R - authentication From c493f0e8d1fd09fe18c657938c0221c5c27f2fd7 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 25 Feb 2022 22:04:21 +0100 Subject: [PATCH 064/292] Cleanup encoding and decoding (#230) --- .../Connection/PostgresConnection.swift | 35 +-- .../New/Data/Optional+PostgresCodable.swift | 69 ----- .../New/PSQLRow-multi-decode.swift | 240 +++++++++--------- Sources/PostgresNIO/New/PostgresCell.swift | 2 +- Sources/PostgresNIO/New/PostgresCodable.swift | 50 +++- Sources/PostgresNIO/New/PostgresQuery.swift | 31 ++- Sources/PostgresNIO/Postgres+PSQLCompat.swift | 46 ---- .../PostgresNIO/PostgresDatabase+Query.swift | 5 +- .../New/Data/Optional+PSQLCodableTests.swift | 67 ----- .../New/PostgresCodableTests.swift | 64 +++++ .../New/PostgresQueryTests.swift | 2 +- dev/generate-psqlrow-multi-decode.sh | 4 +- 12 files changed, 261 insertions(+), 354 deletions(-) delete mode 100644 Sources/PostgresNIO/New/Data/Optional+PostgresCodable.swift delete mode 100644 Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift create mode 100644 Tests/PostgresNIOTests/New/PostgresCodableTests.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 5f00ab79..fcb72953 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -106,15 +106,8 @@ extension PostgresConnection: PostgresDatabase { let resultFuture: EventLoopFuture switch command { - case .query(let query, let binds, let onMetadata, let onRow): - var psqlQuery = PostgresQuery(unsafeSQL: query, binds: .init(capacity: binds.count)) - binds.forEach { - // We can bang the try here as encoding PostgresData does not throw. The throw - // is just an option for the protocol. - try! psqlQuery.appendBinding($0, context: .default) - } - - resultFuture = self.underlying.query(psqlQuery, logger: logger).flatMap { stream in + case .query(let query, let onMetadata, let onRow): + resultFuture = self.underlying.query(query, logger: logger).flatMap { stream in let fields = stream.rowDescription.map { column in PostgresMessage.RowDescription.Field( name: column.name, @@ -132,15 +125,8 @@ extension PostgresConnection: PostgresDatabase { onMetadata(PostgresQueryMetadata(string: stream.commandTag)!) } } - case .queryAll(let query, let binds, let onResult): - var psqlQuery = PostgresQuery(unsafeSQL: query, binds: .init(capacity: binds.count)) - binds.forEach { - // We can bang the try here as encoding PostgresData does not throw. The throw - // is just an option for the protocol. - try! psqlQuery.appendBinding($0, context: .default) - } - - resultFuture = self.underlying.query(psqlQuery, logger: logger).flatMap { rows in + case .queryAll(let query, let onResult): + resultFuture = self.underlying.query(query, logger: logger).flatMap { rows in let fields = rows.rowDescription.map { column in PostgresMessage.RowDescription.Field( name: column.name, @@ -171,10 +157,8 @@ extension PostgresConnection: PostgresDatabase { request.prepared = PreparedQuery(underlying: $0, database: self) } case .executePreparedStatement(let preparedQuery, let binds, let onRow): - var bindings = PostgresBindings() - binds.forEach { data in - try! bindings.append(data, context: .default) - } + var bindings = PostgresBindings(capacity: binds.count) + binds.forEach { bindings.append($0) } let statement = PSQLExecuteStatement( name: preparedQuery.underlying.name, @@ -202,13 +186,10 @@ extension PostgresConnection: PostgresDatabase { } internal enum PostgresCommands: PostgresRequest { - case query(query: String, - binds: [PostgresData], + case query(PostgresQuery, onMetadata: (PostgresQueryMetadata) -> () = { _ in }, onRow: (PostgresRow) throws -> ()) - case queryAll(query: String, - binds: [PostgresData], - onResult: (PostgresQueryResult) -> ()) + case queryAll(PostgresQuery, onResult: (PostgresQueryResult) -> ()) case prepareQuery(request: PrepareQueryRequest) case executePreparedStatement(query: PreparedQuery, binds: [PostgresData], onRow: (PostgresRow) throws -> ()) diff --git a/Sources/PostgresNIO/New/Data/Optional+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Optional+PostgresCodable.swift deleted file mode 100644 index 0c098fad..00000000 --- a/Sources/PostgresNIO/New/Data/Optional+PostgresCodable.swift +++ /dev/null @@ -1,69 +0,0 @@ -import NIOCore - -extension Optional: PostgresDecodable where Wrapped: PostgresDecodable, Wrapped.DecodableType == Wrapped { - typealias DecodableType = Wrapped - - static func decode( - from byteBuffer: inout ByteBuffer, - type: PostgresDataType, - format: PostgresFormat, - context: PostgresDecodingContext - ) throws -> Optional { - preconditionFailure("This should not be called") - } - - static func decodeRaw( - from byteBuffer: inout ByteBuffer?, - type: PostgresDataType, - format: PostgresFormat, - context: PostgresDecodingContext - ) throws -> Self { - guard var buffer = byteBuffer else { - return nil - } - return try DecodableType.decode(from: &buffer, type: type, format: format, context: context) - } -} - -extension Optional: PostgresEncodable where Wrapped: PostgresEncodable { - var psqlType: PostgresDataType { - switch self { - case .some(let value): - return value.psqlType - case .none: - return .null - } - } - - var psqlFormat: PostgresFormat { - switch self { - case .some(let value): - return value.psqlFormat - case .none: - return .binary - } - } - - func encode( - into byteBuffer: inout ByteBuffer, - context: PostgresEncodingContext - ) { - preconditionFailure("Should never be hit, since `encodeRaw` is implemented.") - } - - func encodeRaw( - into byteBuffer: inout ByteBuffer, - context: PostgresEncodingContext - ) throws { - switch self { - case .none: - byteBuffer.writeInteger(-1, as: Int32.self) - case .some(let value): - try value.encodeRaw(into: &byteBuffer, context: context) - } - } -} - -extension Optional: PostgresCodable where Wrapped: PostgresCodable, Wrapped.DecodableType == Wrapped { - -} diff --git a/Sources/PostgresNIO/New/PSQLRow-multi-decode.swift b/Sources/PostgresNIO/New/PSQLRow-multi-decode.swift index 26eeb167..9e0c4ab0 100644 --- a/Sources/PostgresNIO/New/PSQLRow-multi-decode.swift +++ b/Sources/PostgresNIO/New/PSQLRow-multi-decode.swift @@ -12,7 +12,7 @@ extension PSQLRow { let swiftTargetType: Any.Type = T0.self do { - let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0) } catch let code as PostgresCastingError.Code { @@ -41,13 +41,13 @@ extension PSQLRow { var swiftTargetType: Any.Type = T0.self do { - let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 1 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T1.self - let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1) } catch let code as PostgresCastingError.Code { @@ -76,19 +76,19 @@ extension PSQLRow { var swiftTargetType: Any.Type = T0.self do { - let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 1 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T1.self - let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 2 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T2.self - let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2) } catch let code as PostgresCastingError.Code { @@ -117,25 +117,25 @@ extension PSQLRow { var swiftTargetType: Any.Type = T0.self do { - let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 1 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T1.self - let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 2 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T2.self - let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 3 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T3.self - let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3) } catch let code as PostgresCastingError.Code { @@ -164,31 +164,31 @@ extension PSQLRow { var swiftTargetType: Any.Type = T0.self do { - let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 1 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T1.self - let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 2 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T2.self - let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 3 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T3.self - let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 4 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T4.self - let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4) } catch let code as PostgresCastingError.Code { @@ -217,37 +217,37 @@ extension PSQLRow { var swiftTargetType: Any.Type = T0.self do { - let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 1 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T1.self - let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 2 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T2.self - let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 3 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T3.self - let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 4 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T4.self - let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 5 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T5.self - let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5) } catch let code as PostgresCastingError.Code { @@ -276,43 +276,43 @@ extension PSQLRow { var swiftTargetType: Any.Type = T0.self do { - let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 1 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T1.self - let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 2 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T2.self - let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 3 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T3.self - let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 4 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T4.self - let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 5 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T5.self - let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 6 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T6.self - let r6 = try T6.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5, r6) } catch let code as PostgresCastingError.Code { @@ -341,49 +341,49 @@ extension PSQLRow { var swiftTargetType: Any.Type = T0.self do { - let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 1 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T1.self - let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 2 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T2.self - let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 3 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T3.self - let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 4 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T4.self - let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 5 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T5.self - let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 6 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T6.self - let r6 = try T6.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 7 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T7.self - let r7 = try T7.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5, r6, r7) } catch let code as PostgresCastingError.Code { @@ -412,55 +412,55 @@ extension PSQLRow { var swiftTargetType: Any.Type = T0.self do { - let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 1 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T1.self - let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 2 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T2.self - let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 3 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T3.self - let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 4 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T4.self - let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 5 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T5.self - let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 6 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T6.self - let r6 = try T6.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 7 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T7.self - let r7 = try T7.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 8 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T8.self - let r8 = try T8.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5, r6, r7, r8) } catch let code as PostgresCastingError.Code { @@ -489,61 +489,61 @@ extension PSQLRow { var swiftTargetType: Any.Type = T0.self do { - let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 1 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T1.self - let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 2 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T2.self - let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 3 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T3.self - let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 4 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T4.self - let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 5 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T5.self - let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 6 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T6.self - let r6 = try T6.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 7 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T7.self - let r7 = try T7.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 8 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T8.self - let r8 = try T8.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 9 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T9.self - let r9 = try T9.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9) } catch let code as PostgresCastingError.Code { @@ -572,67 +572,67 @@ extension PSQLRow { var swiftTargetType: Any.Type = T0.self do { - let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 1 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T1.self - let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 2 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T2.self - let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 3 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T3.self - let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 4 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T4.self - let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 5 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T5.self - let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 6 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T6.self - let r6 = try T6.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 7 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T7.self - let r7 = try T7.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 8 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T8.self - let r8 = try T8.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 9 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T9.self - let r9 = try T9.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 10 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T10.self - let r10 = try T10.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10) } catch let code as PostgresCastingError.Code { @@ -661,73 +661,73 @@ extension PSQLRow { var swiftTargetType: Any.Type = T0.self do { - let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 1 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T1.self - let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 2 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T2.self - let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 3 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T3.self - let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 4 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T4.self - let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 5 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T5.self - let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 6 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T6.self - let r6 = try T6.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 7 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T7.self - let r7 = try T7.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 8 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T8.self - let r8 = try T8.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 9 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T9.self - let r9 = try T9.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 10 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T10.self - let r10 = try T10.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 11 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T11.self - let r11 = try T11.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11) } catch let code as PostgresCastingError.Code { @@ -756,79 +756,79 @@ extension PSQLRow { var swiftTargetType: Any.Type = T0.self do { - let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 1 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T1.self - let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 2 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T2.self - let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 3 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T3.self - let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 4 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T4.self - let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 5 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T5.self - let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 6 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T6.self - let r6 = try T6.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 7 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T7.self - let r7 = try T7.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 8 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T8.self - let r8 = try T8.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 9 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T9.self - let r9 = try T9.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 10 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T10.self - let r10 = try T10.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 11 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T11.self - let r11 = try T11.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 12 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T12.self - let r12 = try T12.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r12 = try T12._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12) } catch let code as PostgresCastingError.Code { @@ -857,85 +857,85 @@ extension PSQLRow { var swiftTargetType: Any.Type = T0.self do { - let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 1 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T1.self - let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 2 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T2.self - let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 3 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T3.self - let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 4 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T4.self - let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 5 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T5.self - let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 6 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T6.self - let r6 = try T6.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 7 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T7.self - let r7 = try T7.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 8 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T8.self - let r8 = try T8.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 9 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T9.self - let r9 = try T9.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 10 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T10.self - let r10 = try T10.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 11 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T11.self - let r11 = try T11.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 12 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T12.self - let r12 = try T12.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r12 = try T12._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 13 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T13.self - let r13 = try T13.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r13 = try T13._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13) } catch let code as PostgresCastingError.Code { @@ -964,91 +964,91 @@ extension PSQLRow { var swiftTargetType: Any.Type = T0.self do { - let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 1 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T1.self - let r1 = try T1.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 2 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T2.self - let r2 = try T2.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 3 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T3.self - let r3 = try T3.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 4 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T4.self - let r4 = try T4.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 5 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T5.self - let r5 = try T5.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 6 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T6.self - let r6 = try T6.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 7 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T7.self - let r7 = try T7.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 8 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T8.self - let r8 = try T8.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 9 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T9.self - let r9 = try T9.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 10 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T10.self - let r10 = try T10.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 11 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T11.self - let r11 = try T11.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 12 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T12.self - let r12 = try T12.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r12 = try T12._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 13 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T13.self - let r13 = try T13.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r13 = try T13._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) columnIndex = 14 cellData = cellIterator.next().unsafelyUnwrapped column = columnIterator.next().unsafelyUnwrapped swiftTargetType = T14.self - let r14 = try T14.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + let r14 = try T14._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14) } catch let code as PostgresCastingError.Code { diff --git a/Sources/PostgresNIO/New/PostgresCell.swift b/Sources/PostgresNIO/New/PostgresCell.swift index 8d4bcc7c..c7ae8164 100644 --- a/Sources/PostgresNIO/New/PostgresCell.swift +++ b/Sources/PostgresNIO/New/PostgresCell.swift @@ -28,7 +28,7 @@ extension PostgresCell { ) throws -> T { var copy = self.bytes do { - return try T.decodeRaw( + return try T._decodeRaw( from: ©, type: self.dataType, format: self.format, diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift index d961dd08..2ae01e76 100644 --- a/Sources/PostgresNIO/New/PostgresCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -10,18 +10,19 @@ protocol PostgresEncodable { var psqlFormat: PostgresFormat { get } /// Encode the entity into the `byteBuffer` in Postgres binary format, without setting - /// the byte count. This method is called from the default `encodeRaw` implementation. + /// the byte count. This method is called from the ``PostgresBindings``. func encode(into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext) throws - - /// Encode the entity into the `byteBuffer` in Postgres binary format including its - /// leading byte count. This method has a default implementation and may be overriden - /// only for special cases, like `Optional`s. - func encodeRaw(into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext) throws } /// A type that can decode itself from a postgres wire binary representation. +/// +/// If you want to conform a type to PostgresDecodable you must implement the decode method. protocol PostgresDecodable { - associatedtype DecodableType: PostgresDecodable = Self + /// A type definition of the type that actually implements the PostgresDecodable protocol. This is an escape hatch to + /// prevent a cycle in the conformace of the Optional type to PostgresDecodable. + /// + /// String? should be PostgresDecodable, String?? should not be PostgresDecodable + associatedtype _DecodableType: PostgresDecodable = Self /// Decode an entity from the `byteBuffer` in postgres wire format /// @@ -41,10 +42,9 @@ protocol PostgresDecodable { context: PostgresDecodingContext ) throws -> Self - /// Decode an entity from the `byteBuffer` in postgres wire format. - /// This method has a default implementation and may be overriden - /// only for special cases, like `Optional`s. - static func decodeRaw( + /// Decode an entity from the `byteBuffer` in postgres wire format. This method has a default implementation and + /// is only overwritten for `Optional`s. Other than in the + static func _decodeRaw( from byteBuffer: inout ByteBuffer?, type: PostgresDataType, format: PostgresFormat, @@ -54,7 +54,7 @@ protocol PostgresDecodable { extension PostgresDecodable { @inlinable - static func decodeRaw( + static func _decodeRaw( from byteBuffer: inout ByteBuffer?, type: PostgresDataType, format: PostgresFormat, @@ -108,3 +108,29 @@ struct PostgresDecodingContext { self.jsonDecoder = jsonDecoder } } + +extension PostgresDecodingContext where JSONDecoder == Foundation.JSONDecoder { + static let `default` = PostgresDecodingContext(jsonDecoder: Foundation.JSONDecoder()) +} + +extension Optional: PostgresDecodable where Wrapped: PostgresDecodable, Wrapped._DecodableType == Wrapped { + typealias _DecodableType = Wrapped + + static func decode(from byteBuffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext) throws -> Optional { + preconditionFailure("This should not be called") + } + + static func _decodeRaw( + from byteBuffer: inout ByteBuffer?, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws -> Optional { + switch byteBuffer { + case .some(var buffer): + return try Wrapped.decode(from: &buffer, type: type, format: format, context: context) + case .none: + return .none + } + } +} diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 62a74cce..b1f00f0a 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -22,13 +22,6 @@ extension PostgresQuery: ExpressibleByStringInterpolation { self.sql = value self.binds = PostgresBindings() } - - mutating func appendBinding( - _ value: Value, - context: PostgresEncodingContext - ) throws { - try self.binds.append(value, context: context) - } } extension PostgresQuery { @@ -53,7 +46,13 @@ extension PostgresQuery { } mutating func appendInterpolation(_ value: Optional) throws { - try self.binds.append(value, context: .default) + switch value { + case .none: + self.binds.appendNull() + case .some(let value): + try self.binds.append(value, context: .default) + } + self.sql.append(contentsOf: "$\(self.binds.count)") } @@ -110,6 +109,22 @@ struct PostgresBindings: Hashable { self.bytes.reserveCapacity(128 * capacity) } + mutating func appendNull() { + self.bytes.writeInteger(-1, as: Int32.self) + self.metadata.append(.init(dataType: .null, format: .binary)) + } + + mutating func append(_ postgresData: PostgresData) { + switch postgresData.value { + case .none: + self.bytes.writeInteger(-1, as: Int32.self) + case .some(var input): + self.bytes.writeInteger(Int32(input.readableBytes)) + self.bytes.writeBuffer(&input) + } + self.metadata.append(.init(dataType: postgresData.type, format: .binary)) + } + mutating func append( _ value: Value, context: PostgresEncodingContext diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index f2d97112..8c7e7db1 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -1,51 +1,5 @@ import NIOCore -extension PostgresData: PostgresEncodable { - var psqlType: PostgresDataType { - self.type - } - - var psqlFormat: PostgresFormat { - .binary - } - - func encode( - into byteBuffer: inout ByteBuffer, - context: PostgresEncodingContext - ) throws { - preconditionFailure("Should never be hit, since `encodeRaw` is implemented.") - } - - // encoding - func encodeRaw( - into byteBuffer: inout ByteBuffer, - context: PostgresEncodingContext - ) { - switch self.value { - case .none: - byteBuffer.writeInteger(-1, as: Int32.self) - case .some(var input): - byteBuffer.writeInteger(Int32(input.readableBytes)) - byteBuffer.writeBuffer(&input) - } - } -} - -extension PostgresData: PostgresDecodable { - static func decode( - from buffer: inout ByteBuffer, - type: PostgresDataType, - format: PostgresFormat, - context: PostgresDecodingContext - ) throws -> Self { - let myBuffer = buffer.readSlice(length: buffer.readableBytes)! - - return PostgresData(type: PostgresDataType(UInt32(type.rawValue)), typeModifier: nil, formatCode: .binary, value: myBuffer) - } -} - -extension PostgresData: PostgresCodable {} - extension PSQLError { func toPostgresError() -> Error { switch self.base { diff --git a/Sources/PostgresNIO/PostgresDatabase+Query.swift b/Sources/PostgresNIO/PostgresDatabase+Query.swift index b6c0b183..95abb6fc 100644 --- a/Sources/PostgresNIO/PostgresDatabase+Query.swift +++ b/Sources/PostgresNIO/PostgresDatabase+Query.swift @@ -23,7 +23,10 @@ extension PostgresDatabase { onMetadata: @escaping (PostgresQueryMetadata) -> () = { _ in }, onRow: @escaping (PostgresRow) throws -> () ) -> EventLoopFuture { - let request = PostgresCommands.query(query: string, binds: binds, onMetadata: onMetadata, onRow: onRow) + var bindings = PostgresBindings(capacity: binds.count) + binds.forEach { bindings.append($0) } + let query = PostgresQuery(unsafeSQL: string, binds: bindings) + let request = PostgresCommands.query(query, onMetadata: onMetadata, onRow: onRow) return self.send(request, logger: logger) } diff --git a/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift deleted file mode 100644 index 1d689b0e..00000000 --- a/Tests/PostgresNIOTests/New/Data/Optional+PSQLCodableTests.swift +++ /dev/null @@ -1,67 +0,0 @@ -import XCTest -import NIOCore -@testable import PostgresNIO - -class Optional_PSQLCodableTests: XCTestCase { - - func testRoundTripSomeString() { - let value: String? = "Hello World" - - var buffer = ByteBuffer() - XCTAssertNoThrow(try value.encodeRaw(into: &buffer, context: .forTests())) - XCTAssertEqual(value.psqlType, .text) - XCTAssertEqual(buffer.readInteger(as: Int32.self), 11) - - var result: String? - var optBuffer: ByteBuffer? = buffer - XCTAssertNoThrow(result = try String?.decodeRaw(from: &optBuffer, type: .text, format: .binary, context: .forTests())) - XCTAssertEqual(result, value) - } - - func testRoundTripNoneString() { - let value: Optional = .none - - var buffer = ByteBuffer() - XCTAssertNoThrow(try value.encodeRaw(into: &buffer, context: .forTests())) - XCTAssertEqual(buffer.readableBytes, 4) - XCTAssertEqual(buffer.getInteger(at: 0, as: Int32.self), -1) - XCTAssertEqual(value.psqlType, .null) - - var result: String? - var inBuffer: ByteBuffer? = nil - XCTAssertNoThrow(result = try String?.decodeRaw(from: &inBuffer, type: .text, format: .binary, context: .forTests())) - XCTAssertEqual(result, value) - } - - func testRoundTripSomeUUIDAsPSQLEncodable() { - let value: Optional = UUID() - let encodable: PostgresEncodable = value - - var buffer = ByteBuffer() - XCTAssertEqual(encodable.psqlType, .uuid) - XCTAssertNoThrow(try encodable.encodeRaw(into: &buffer, context: .forTests())) - XCTAssertEqual(buffer.readableBytes, 20) - XCTAssertEqual(buffer.readInteger(as: Int32.self), 16) - - var result: UUID? - var optBuffer: ByteBuffer? = buffer - XCTAssertNoThrow(result = try UUID?.decodeRaw(from: &optBuffer, type: .uuid, format: .binary, context: .forTests())) - XCTAssertEqual(result, value) - } - - func testRoundTripNoneUUIDAsPSQLEncodable() { - let value: Optional = .none - let encodable: PostgresEncodable = value - - var buffer = ByteBuffer() - XCTAssertEqual(encodable.psqlType, .null) - XCTAssertNoThrow(try encodable.encodeRaw(into: &buffer, context: .forTests())) - XCTAssertEqual(buffer.readableBytes, 4) - XCTAssertEqual(buffer.readInteger(as: Int32.self), -1) - - var result: UUID? - var inBuffer: ByteBuffer? = nil - XCTAssertNoThrow(result = try UUID?.decodeRaw(from: &inBuffer, type: .text, format: .binary, context: .forTests())) - XCTAssertEqual(result, value) - } -} diff --git a/Tests/PostgresNIOTests/New/PostgresCodableTests.swift b/Tests/PostgresNIOTests/New/PostgresCodableTests.swift new file mode 100644 index 00000000..bf300c1f --- /dev/null +++ b/Tests/PostgresNIOTests/New/PostgresCodableTests.swift @@ -0,0 +1,64 @@ +import XCTest +@testable import PostgresNIO + +final class PostgresCodableTests: XCTestCase { + + func testDecodeAnOptionalFromARow() { + let row = PSQLRow( + data: .makeTestDataRow(nil, ByteBuffer(string: "Hello world!")), + lookupTable: ["id": 0, "name": 1], + columns: [ + RowDescription.Column( + name: "id", + tableOID: 1, + columnAttributeNumber: 1, + dataType: .text, + dataTypeSize: 0, + dataTypeModifier: 0, + format: .binary + ), + RowDescription.Column( + name: "id", + tableOID: 1, + columnAttributeNumber: 1, + dataType: .text, + dataTypeSize: 0, + dataTypeModifier: 0, + format: .binary + ) + ] + ) + + var result: (String?, String?) + XCTAssertNoThrow(result = try row.decode((String?, String?).self, context: .default)) + XCTAssertNil(result.0) + XCTAssertEqual(result.1, "Hello world!") + } + + func testDecodeMissingValueError() { + let row = PSQLRow( + data: .makeTestDataRow(nil), + lookupTable: ["name": 0], + columns: [ + RowDescription.Column( + name: "id", + tableOID: 1, + columnAttributeNumber: 1, + dataType: .text, + dataTypeSize: 0, + dataTypeModifier: 0, + format: .binary + ) + ] + ) + + XCTAssertThrowsError(try row.decode(String.self, context: .default)) { + XCTAssertEqual(($0 as? PostgresCastingError)?.line, #line - 1) + XCTAssertEqual(($0 as? PostgresCastingError)?.file, #file) + + XCTAssertEqual(($0 as? PostgresCastingError)?.code, .missingData) + XCTAssert(($0 as? PostgresCastingError)?.targetType == String.self) + XCTAssertEqual(($0 as? PostgresCastingError)?.postgresType, .text) + } + } +} diff --git a/Tests/PostgresNIOTests/New/PostgresQueryTests.swift b/Tests/PostgresNIOTests/New/PostgresQueryTests.swift index 80f52ea5..43c39a3a 100644 --- a/Tests/PostgresNIOTests/New/PostgresQueryTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresQueryTests.swift @@ -66,7 +66,7 @@ final class PostgresQueryTests: XCTestCase { var query = PostgresQuery(unsafeSQL: sql, binds: .init(capacity: 5)) for value in 1...5 { - XCTAssertNoThrow(try query.appendBinding(Int(value), context: .default)) + XCTAssertNoThrow(try query.binds.append(Int(value), context: .default)) } XCTAssertEqual(query.sql, "INSERT INTO test (id) SET ($1, $2, $3, $4, $5);") diff --git a/dev/generate-psqlrow-multi-decode.sh b/dev/generate-psqlrow-multi-decode.sh index f2be1ad1..5fee4a93 100755 --- a/dev/generate-psqlrow-multi-decode.sh +++ b/dev/generate-psqlrow-multi-decode.sh @@ -49,14 +49,14 @@ function gen() { echo echo " do {" - echo " let r0 = try T0.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context)" + echo " let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context)" echo for ((n = 1; n<$how_many; n +=1)); do echo " columnIndex = $n" echo " cellData = cellIterator.next().unsafelyUnwrapped" echo " column = columnIterator.next().unsafelyUnwrapped" echo " swiftTargetType = T$n.self" - echo " let r$n = try T$n.decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context)" + echo " let r$n = try T$n._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context)" echo done From 2938198672124f5fe0845d95119878e5389e5eef Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 26 Feb 2022 00:09:43 +0100 Subject: [PATCH 065/292] Merge PSQLRow into PostgresRow (#219) --- .../Connection/PostgresConnection.swift | 59 +-- .../PostgresDatabase+PreparedQuery.swift | 19 - Sources/PostgresNIO/Data/PostgresRow.swift | 336 +++++++++++++++--- .../New/PSQLRow-multi-decode.swift | 17 +- Sources/PostgresNIO/New/PSQLRow.swift | 70 ---- Sources/PostgresNIO/New/PSQLRowStream.swift | 26 +- Sources/PostgresNIO/New/PostgresCell.swift | 12 +- .../PostgresNIO/New/PostgresRowSequence.swift | 28 +- .../PSQLIntegrationTests.swift | 208 +++++------ Tests/IntegrationTests/PerformanceTests.swift | 80 +++-- Tests/IntegrationTests/PostgresNIOTests.swift | 295 ++++++++------- .../New/PSQLRowStreamTests.swift | 10 +- .../New/PostgresCodableTests.swift | 4 +- .../New/PostgresRowSequenceTests.swift | 12 +- .../New/PostgresRowTests.swift | 124 +++++++ dev/generate-psqlrow-multi-decode.sh | 4 +- 16 files changed, 775 insertions(+), 529 deletions(-) delete mode 100644 Sources/PostgresNIO/New/PSQLRow.swift create mode 100644 Tests/PostgresNIOTests/New/PostgresRowTests.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index fcb72953..0962482d 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -108,47 +108,15 @@ extension PostgresConnection: PostgresDatabase { switch command { case .query(let query, let onMetadata, let onRow): resultFuture = self.underlying.query(query, logger: logger).flatMap { stream in - let fields = stream.rowDescription.map { column in - PostgresMessage.RowDescription.Field( - name: column.name, - tableOID: UInt32(column.tableOID), - columnAttributeNumber: column.columnAttributeNumber, - dataType: PostgresDataType(UInt32(column.dataType.rawValue)), - dataTypeSize: column.dataTypeSize, - dataTypeModifier: column.dataTypeModifier, - formatCode: .init(psqlFormatCode: column.format) - ) - } - - let lookupTable = PostgresRow.LookupTable(rowDescription: .init(fields: fields), resultFormat: [.binary]) - return stream.iterateRowsWithoutBackpressureOption(lookupTable: lookupTable, onRow: onRow).map { _ in + return stream.onRow(onRow).map { _ in onMetadata(PostgresQueryMetadata(string: stream.commandTag)!) } } + case .queryAll(let query, let onResult): resultFuture = self.underlying.query(query, logger: logger).flatMap { rows in - let fields = rows.rowDescription.map { column in - PostgresMessage.RowDescription.Field( - name: column.name, - tableOID: UInt32(column.tableOID), - columnAttributeNumber: column.columnAttributeNumber, - dataType: PostgresDataType(UInt32(column.dataType.rawValue)), - dataTypeSize: column.dataTypeSize, - dataTypeModifier: column.dataTypeModifier, - formatCode: .init(psqlFormatCode: column.format) - ) - } - - let lookupTable = PostgresRow.LookupTable(rowDescription: .init(fields: fields), resultFormat: [.binary]) return rows.all().map { allrows in - let r = allrows.map { psqlRow -> PostgresRow in - let columns = psqlRow.data.map { - PostgresMessage.DataRow.Column(value: $0) - } - return PostgresRow(dataRow: .init(columns: columns), lookupTable: lookupTable) - } - - onResult(.init(metadata: PostgresQueryMetadata(string: rows.commandTag)!, rows: r)) + onResult(.init(metadata: PostgresQueryMetadata(string: rows.commandTag)!, rows: allrows)) } } @@ -156,6 +124,7 @@ extension PostgresConnection: PostgresDatabase { resultFuture = self.underlying.prepareStatement(request.query, with: request.name, logger: self.logger).map { request.prepared = PreparedQuery(underlying: $0, database: self) } + case .executePreparedStatement(let preparedQuery, let binds, let onRow): var bindings = PostgresBindings(capacity: binds.count) binds.forEach { bindings.append($0) } @@ -167,11 +136,7 @@ extension PostgresConnection: PostgresDatabase { ) resultFuture = self.underlying.execute(statement, logger: logger).flatMap { rows in - guard let lookupTable = preparedQuery.lookupTable else { - return self.eventLoop.makeSucceededFuture(()) - } - - return rows.iterateRowsWithoutBackpressureOption(lookupTable: lookupTable, onRow: onRow) + return rows.onRow(onRow) } } @@ -206,20 +171,6 @@ internal enum PostgresCommands: PostgresRequest { } } -extension PSQLRowStream { - - func iterateRowsWithoutBackpressureOption(lookupTable: PostgresRow.LookupTable, onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { - self.onRow { psqlRow in - let columns = psqlRow.data.map { - PostgresMessage.DataRow.Column(value: $0) - } - - let row = PostgresRow(dataRow: .init(columns: columns), lookupTable: lookupTable) - try onRow(row) - } - } -} - // MARK: Notifications /// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support. diff --git a/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift b/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift index cf315b19..074ba6de 100644 --- a/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift +++ b/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift @@ -28,29 +28,10 @@ extension PostgresDatabase { public struct PreparedQuery { let underlying: PSQLPreparedStatement - let lookupTable: PostgresRow.LookupTable? let database: PostgresDatabase init(underlying: PSQLPreparedStatement, database: PostgresDatabase) { self.underlying = underlying - self.lookupTable = underlying.rowDescription.flatMap { - rowDescription -> PostgresRow.LookupTable in - - let fields = rowDescription.columns.map { column in - PostgresMessage.RowDescription.Field( - name: column.name, - tableOID: UInt32(column.tableOID), - columnAttributeNumber: column.columnAttributeNumber, - dataType: PostgresDataType(UInt32(column.dataType.rawValue)), - dataTypeSize: column.dataTypeSize, - dataTypeModifier: column.dataTypeModifier, - formatCode: .init(psqlFormatCode: column.format) - ) - } - - return .init(rowDescription: .init(fields: fields), resultFormat: [.binary]) - } - self.database = database } diff --git a/Sources/PostgresNIO/Data/PostgresRow.swift b/Sources/PostgresNIO/Data/PostgresRow.swift index 7b08b360..3ac20c5e 100644 --- a/Sources/PostgresNIO/Data/PostgresRow.swift +++ b/Sources/PostgresNIO/Data/PostgresRow.swift @@ -1,76 +1,312 @@ -public struct PostgresRow: CustomStringConvertible { - final class LookupTable { - let rowDescription: PostgresMessage.RowDescription - let resultFormat: [PostgresFormat] - - struct Value { - let index: Int - let field: PostgresMessage.RowDescription.Field +import NIOCore +import class Foundation.JSONDecoder + +/// `PostgresRow` represents a single table row that is received from the server for a query or a prepared statement. +/// Its element type is ``PostgresCell``. +/// +/// - Warning: Please note that random access to cells in a ``PostgresRow`` have O(n) time complexity. If you require +/// random access to cells in O(1) create a new ``PostgresRandomAccessRow`` with the given row and +/// access it instead. +public struct PostgresRow { + let lookupTable: [String: Int] + let data: DataRow + + let columns: [RowDescription.Column] + + init(data: DataRow, lookupTable: [String: Int], columns: [RowDescription.Column]) { + self.data = data + self.lookupTable = lookupTable + self.columns = columns + } +} + +extension PostgresRow: Equatable { + public static func ==(lhs: Self, rhs: Self) -> Bool { + // we don't need to compare the lookup table here, as the looup table is only derived + // from the column description. + lhs.data == rhs.data && lhs.columns == rhs.columns + } +} + +extension PostgresRow: Sequence { + public typealias Element = PostgresCell + + public struct Iterator: IteratorProtocol { + public typealias Element = PostgresCell + + private(set) var columnIndex: Array.Index + private(set) var columnIterator: Array.Iterator + private(set) var dataIterator: DataRow.Iterator + + init(_ row: PostgresRow) { + self.columnIndex = 0 + self.columnIterator = row.columns.makeIterator() + self.dataIterator = row.data.makeIterator() } - - private var _storage: [String: Value]? - var storage: [String: Value] { - if let existing = self._storage { - return existing - } else { - let all = self.rowDescription.fields.enumerated().map { (index, field) in - return (field.name, Value(index: index, field: field)) - } - let storage = [String: Value](all) { a, b in - // take the first value - return a - } - self._storage = storage - return storage + + public mutating func next() -> PostgresCell? { + guard let bytes = self.dataIterator.next() else { + return nil } + + let column = self.columnIterator.next()! + + defer { self.columnIndex += 1 } + + return PostgresCell( + bytes: bytes, + dataType: column.dataType, + format: column.format, + columnName: column.name, + columnIndex: columnIndex + ) } + } + + public func makeIterator() -> Iterator { + Iterator(self) + } +} - init( - rowDescription: PostgresMessage.RowDescription, - resultFormat: [PostgresFormat] - ) { - self.rowDescription = rowDescription - self.resultFormat = resultFormat +extension PostgresRow: Collection { + public struct Index: Comparable { + var cellIndex: DataRow.Index + var columnIndex: Array.Index + + // Only needed implementation for comparable. The compiler synthesizes the rest from this. + public static func < (lhs: Self, rhs: Self) -> Bool { + lhs.columnIndex < rhs.columnIndex } + } - func lookup(column: String) -> Value? { - if let value = self.storage[column] { - return value - } else { - return nil - } + public subscript(position: Index) -> PostgresCell { + let column = self.columns[position.columnIndex] + return PostgresCell( + bytes: self.data[position.cellIndex], + dataType: column.dataType, + format: column.format, + columnName: column.name, + columnIndex: position.columnIndex + ) + } + + public var startIndex: Index { + Index( + cellIndex: self.data.startIndex, + columnIndex: 0 + ) + } + + public var endIndex: Index { + Index( + cellIndex: self.data.endIndex, + columnIndex: self.columns.count + ) + } + + public func index(after i: Index) -> Index { + Index( + cellIndex: self.data.index(after: i.cellIndex), + columnIndex: self.columns.index(after: i.columnIndex) + ) + } + + public var count: Int { + self.data.count + } +} + +extension PostgresRow { + public func makeRandomAccess() -> PostgresRandomAccessRow { + PostgresRandomAccessRow(self) + } +} + +/// A random access row of ``PostgresCell``s. Its initialization is O(n) where n is the number of columns +/// in the row. All subsequent cell access are O(1). +public struct PostgresRandomAccessRow { + let columns: [RowDescription.Column] + let cells: [ByteBuffer?] + let lookupTable: [String: Int] + + init(_ row: PostgresRow) { + self.cells = [ByteBuffer?](row.data) + self.columns = row.columns + self.lookupTable = row.lookupTable + } +} + +extension PostgresRandomAccessRow: RandomAccessCollection { + public typealias Element = PostgresCell + public typealias Index = Int + + public var startIndex: Int { + 0 + } + + public var endIndex: Int { + self.columns.count + } + + public var count: Int { + self.columns.count + } + + public subscript(index: Int) -> PostgresCell { + guard index < self.endIndex else { + preconditionFailure("index out of bounds") + } + let column = self.columns[index] + return PostgresCell( + bytes: self.cells[index], + dataType: column.dataType, + format: column.format, + columnName: column.name, + columnIndex: index + ) + } + + public subscript(name: String) -> PostgresCell { + guard let index = self.lookupTable[name] else { + fatalError(#"A column "\#(name)" does not exist."#) + } + return self[index] + } +} + +extension PostgresRandomAccessRow { + public subscript(data index: Int) -> PostgresData { + guard index < self.endIndex else { + preconditionFailure("index out of bounds") + } + let column = self.columns[index] + return PostgresData( + type: column.dataType, + typeModifier: column.dataTypeModifier, + formatCode: .binary, + value: self.cells[index] + ) + } + + public subscript(data column: String) -> PostgresData { + guard let index = self.lookupTable[column] else { + fatalError(#"A column "\#(column)" does not exist."#) } + return self[data: index] } +} - public let dataRow: PostgresMessage.DataRow +extension PostgresRandomAccessRow { + /// Access the data in the provided column and decode it into the target type. + /// + /// - Parameters: + /// - column: The column name to read the data from + /// - type: The type to decode the data into + /// - Throws: The error of the decoding implementation. See also `PSQLDecodable` protocol for this. + /// - Returns: The decoded value of Type T. + func decode( + column: String, + as type: T.Type, + context: PostgresDecodingContext, + file: String = #file, line: Int = #line + ) throws -> T { + guard let index = self.lookupTable[column] else { + fatalError(#"A column "\#(column)" does not exist."#) + } + return try self.decode(column: index, as: type, context: context, file: file, line: line) + } + + /// Access the data in the provided column and decode it into the target type. + /// + /// - Parameters: + /// - column: The column index to read the data from + /// - type: The type to decode the data into + /// - Throws: The error of the decoding implementation. See also `PSQLDecodable` protocol for this. + /// - Returns: The decoded value of Type T. + func decode( + column index: Int, + as type: T.Type, + context: PostgresDecodingContext, + file: String = #file, line: Int = #line + ) throws -> T { + precondition(index < self.columns.count) + + let column = self.columns[index] + + var cellSlice = self.cells[index] + do { + return try T._decodeRaw(from: &cellSlice, type: column.dataType, format: column.format, context: context) + } catch let code as PostgresCastingError.Code { + throw PostgresCastingError( + code: code, + columnName: self.columns[index].name, + columnIndex: index, + targetType: T.self, + postgresType: self.columns[index].dataType, + postgresFormat: self.columns[index].format, + postgresData: cellSlice, + file: file, + line: line + ) + } + } +} + +// MARK: Deprecated API + +extension PostgresRow { public var rowDescription: PostgresMessage.RowDescription { - self.lookupTable.rowDescription + let fields = self.columns.map { column in + PostgresMessage.RowDescription.Field( + name: column.name, + tableOID: UInt32(column.tableOID), + columnAttributeNumber: column.columnAttributeNumber, + dataType: PostgresDataType(UInt32(column.dataType.rawValue)), + dataTypeSize: column.dataTypeSize, + dataTypeModifier: column.dataTypeModifier, + formatCode: .init(psqlFormatCode: column.format) + ) + } + return PostgresMessage.RowDescription(fields: fields) } - let lookupTable: LookupTable + public var dataRow: PostgresMessage.DataRow { + let columns = self.data.map { + PostgresMessage.DataRow.Column(value: $0) + } + return PostgresMessage.DataRow(columns: columns) + } + @available(*, deprecated, message: """ + This call is O(n) where n is the number of cells in the row. For random access to cells + in a row create a PostgresRandomAccessCollection from the row first and use its subscript + methods. + """) public func column(_ column: String) -> PostgresData? { - guard let entry = self.lookupTable.lookup(column: column) else { + guard let index = self.lookupTable[column] else { return nil } - let formatCode: PostgresFormat - switch self.lookupTable.resultFormat.count { - case 1: formatCode = self.lookupTable.resultFormat[0] - default: formatCode = entry.field.formatCode - } + return PostgresData( - type: entry.field.dataType, - typeModifier: entry.field.dataTypeModifier, - formatCode: formatCode, - value: self.dataRow.columns[entry.index].value + type: self.columns[index].dataType, + typeModifier: self.columns[index].dataTypeModifier, + formatCode: .binary, + value: self.data[column: index] ) } +} +extension PostgresRow: CustomStringConvertible { public var description: String { var row: [String: PostgresData] = [:] - for field in self.lookupTable.rowDescription.fields { - row[field.name] = self.column(field.name) + for cell in self { + row[cell.columnName] = PostgresData( + type: cell.dataType, + typeModifier: 0, + formatCode: cell.format, + value: cell.bytes + ) } return row.description } diff --git a/Sources/PostgresNIO/New/PSQLRow-multi-decode.swift b/Sources/PostgresNIO/New/PSQLRow-multi-decode.swift index 9e0c4ab0..ef67c7ac 100644 --- a/Sources/PostgresNIO/New/PSQLRow-multi-decode.swift +++ b/Sources/PostgresNIO/New/PSQLRow-multi-decode.swift @@ -1,7 +1,6 @@ /// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-psqlrow-multi-decode.sh -extension PSQLRow { - @inlinable +extension PostgresRow { func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0) { precondition(self.columns.count >= 1) let columnIndex = 0 @@ -30,7 +29,6 @@ extension PSQLRow { } } - @inlinable func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1) { precondition(self.columns.count >= 2) var columnIndex = 0 @@ -65,7 +63,6 @@ extension PSQLRow { } } - @inlinable func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2) { precondition(self.columns.count >= 3) var columnIndex = 0 @@ -106,7 +103,6 @@ extension PSQLRow { } } - @inlinable func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3) { precondition(self.columns.count >= 4) var columnIndex = 0 @@ -153,7 +149,6 @@ extension PSQLRow { } } - @inlinable func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { precondition(self.columns.count >= 5) var columnIndex = 0 @@ -206,7 +201,6 @@ extension PSQLRow { } } - @inlinable func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { precondition(self.columns.count >= 6) var columnIndex = 0 @@ -265,7 +259,6 @@ extension PSQLRow { } } - @inlinable func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { precondition(self.columns.count >= 7) var columnIndex = 0 @@ -330,7 +323,6 @@ extension PSQLRow { } } - @inlinable func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { precondition(self.columns.count >= 8) var columnIndex = 0 @@ -401,7 +393,6 @@ extension PSQLRow { } } - @inlinable func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { precondition(self.columns.count >= 9) var columnIndex = 0 @@ -478,7 +469,6 @@ extension PSQLRow { } } - @inlinable func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { precondition(self.columns.count >= 10) var columnIndex = 0 @@ -561,7 +551,6 @@ extension PSQLRow { } } - @inlinable func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { precondition(self.columns.count >= 11) var columnIndex = 0 @@ -650,7 +639,6 @@ extension PSQLRow { } } - @inlinable func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { precondition(self.columns.count >= 12) var columnIndex = 0 @@ -745,7 +733,6 @@ extension PSQLRow { } } - @inlinable func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { precondition(self.columns.count >= 13) var columnIndex = 0 @@ -846,7 +833,6 @@ extension PSQLRow { } } - @inlinable func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { precondition(self.columns.count >= 14) var columnIndex = 0 @@ -953,7 +939,6 @@ extension PSQLRow { } } - @inlinable func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { precondition(self.columns.count >= 15) var columnIndex = 0 diff --git a/Sources/PostgresNIO/New/PSQLRow.swift b/Sources/PostgresNIO/New/PSQLRow.swift deleted file mode 100644 index 91389538..00000000 --- a/Sources/PostgresNIO/New/PSQLRow.swift +++ /dev/null @@ -1,70 +0,0 @@ -import NIOCore -import Foundation - -/// `PSQLRow` represents a single row that was received from the Postgres Server. -struct PSQLRow { - internal let lookupTable: [String: Int] - internal let data: DataRow - - internal let columns: [RowDescription.Column] - - internal init(data: DataRow, lookupTable: [String: Int], columns: [RowDescription.Column]) { - self.data = data - self.lookupTable = lookupTable - self.columns = columns - } -} - -extension PSQLRow: Equatable { - static func ==(lhs: Self, rhs: Self) -> Bool { - lhs.data == rhs.data && lhs.columns == rhs.columns - } -} - -extension PSQLRow { - /// Access the data in the provided column and decode it into the target type. - /// - /// - Parameters: - /// - column: The column name to read the data from - /// - type: The type to decode the data into - /// - Throws: The error of the decoding implementation. See also `PostgresDecodable` protocol for this. - /// - Returns: The decoded value of Type T. - func decode(column: String, as type: T.Type, jsonDecoder: JSONDecoder, file: String = #file, line: Int = #line) throws -> T { - guard let index = self.lookupTable[column] else { - preconditionFailure("A column '\(column)' does not exist.") - } - - return try self.decode(column: index, as: type, jsonDecoder: jsonDecoder, file: file, line: line) - } - - /// Access the data in the provided column and decode it into the target type. - /// - /// - Parameters: - /// - column: The column index to read the data from - /// - type: The type to decode the data into - /// - Throws: The error of the decoding implementation. See also `PostgresDecodable` protocol for this. - /// - Returns: The decoded value of Type T. - func decode(column index: Int, as type: T.Type, jsonDecoder: JSONDecoder, file: String = #file, line: Int = #line) throws -> T { - precondition(index < self.data.columnCount) - - let column = self.columns[index] - let context = PostgresDecodingContext(jsonDecoder: jsonDecoder) - - // Safe to force unwrap here, as we have ensured above that the row has enough columns - var cellSlice = self.data[column: index]! - - return try T.decode(from: &cellSlice, type: column.dataType, format: column.format, context: context) - } -} - -extension PSQLRow { - // TODO: Remove this function. Only here to keep the tests running as of today. - func decode(column: String, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { - try self.decode(column: column, as: type, jsonDecoder: JSONDecoder(), file: file, line: line) - } - - // TODO: Remove this function. Only here to keep the tests running as of today. - func decode(column index: Int, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { - try self.decode(column: index, as: type, jsonDecoder: JSONDecoder(), file: file, line: line) - } -} diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index e69d219b..787c6cef 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -18,8 +18,8 @@ final class PSQLRowStream { private enum DownstreamState { case waitingForConsumer(BufferState) - case iteratingRows(onRow: (PSQLRow) throws -> (), EventLoopPromise, PSQLRowsDataSource) - case waitingForAll([PSQLRow], EventLoopPromise<[PSQLRow]>, PSQLRowsDataSource) + case iteratingRows(onRow: (PostgresRow) throws -> (), EventLoopPromise, PSQLRowsDataSource) + case waitingForAll([PostgresRow], EventLoopPromise<[PostgresRow]>, PSQLRowsDataSource) case consumed(Result) #if swift(>=5.5) && canImport(_Concurrency) @@ -145,7 +145,7 @@ final class PSQLRowStream { // MARK: Consume in array - func all() -> EventLoopFuture<[PSQLRow]> { + func all() -> EventLoopFuture<[PostgresRow]> { if self.eventLoop.inEventLoop { return self.all0() } else { @@ -155,7 +155,7 @@ final class PSQLRowStream { } } - private func all0() -> EventLoopFuture<[PSQLRow]> { + private func all0() -> EventLoopFuture<[PostgresRow]> { self.eventLoop.preconditionInEventLoop() guard case .waitingForConsumer(let bufferState) = self.downstreamState else { @@ -164,9 +164,9 @@ final class PSQLRowStream { switch bufferState { case .streaming(let bufferedRows, let dataSource): - let promise = self.eventLoop.makePromise(of: [PSQLRow].self) + let promise = self.eventLoop.makePromise(of: [PostgresRow].self) let rows = bufferedRows.map { data in - PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription) + PostgresRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription) } self.downstreamState = .waitingForAll(rows, promise, dataSource) // immediately request more @@ -175,7 +175,7 @@ final class PSQLRowStream { case .finished(let buffer, let commandTag): let rows = buffer.map { - PSQLRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription) + PostgresRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription) } self.downstreamState = .consumed(.success(commandTag)) @@ -189,7 +189,7 @@ final class PSQLRowStream { // MARK: Consume on EventLoop - func onRow(_ onRow: @escaping (PSQLRow) throws -> ()) -> EventLoopFuture { + func onRow(_ onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { if self.eventLoop.inEventLoop { return self.onRow0(onRow) } else { @@ -199,7 +199,7 @@ final class PSQLRowStream { } } - private func onRow0(_ onRow: @escaping (PSQLRow) throws -> ()) -> EventLoopFuture { + private func onRow0(_ onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { self.eventLoop.preconditionInEventLoop() guard case .waitingForConsumer(let bufferState) = self.downstreamState else { @@ -211,7 +211,7 @@ final class PSQLRowStream { let promise = self.eventLoop.makePromise(of: Void.self) do { for data in buffer { - let row = PSQLRow( + let row = PostgresRow( data: data, lookupTable: self.lookupTable, columns: self.rowDescription @@ -234,7 +234,7 @@ final class PSQLRowStream { case .finished(let buffer, let commandTag): do { for data in buffer { - let row = PSQLRow( + let row = PostgresRow( data: data, lookupTable: self.lookupTable, columns: self.rowDescription @@ -279,7 +279,7 @@ final class PSQLRowStream { case .iteratingRows(let onRow, let promise, let dataSource): do { for data in newRows { - let row = PSQLRow( + let row = PostgresRow( data: data, lookupTable: self.lookupTable, columns: self.rowDescription @@ -297,7 +297,7 @@ final class PSQLRowStream { case .waitingForAll(var rows, let promise, let dataSource): newRows.forEach { data in - let row = PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription) + let row = PostgresRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription) rows.append(row) } self.downstreamState = .waitingForAll(rows, promise, dataSource) diff --git a/Sources/PostgresNIO/New/PostgresCell.swift b/Sources/PostgresNIO/New/PostgresCell.swift index c7ae8164..a29eacd6 100644 --- a/Sources/PostgresNIO/New/PostgresCell.swift +++ b/Sources/PostgresNIO/New/PostgresCell.swift @@ -1,12 +1,12 @@ import NIOCore -struct PostgresCell: Equatable { - var bytes: ByteBuffer? - var dataType: PostgresDataType - var format: PostgresFormat +public struct PostgresCell: Equatable { + public var bytes: ByteBuffer? + public var dataType: PostgresDataType + public var format: PostgresFormat - var columnName: String - var columnIndex: Int + public var columnName: String + public var columnIndex: Int init(bytes: ByteBuffer?, dataType: PostgresDataType, format: PostgresFormat, columnName: String, columnIndex: Int) { self.bytes = bytes diff --git a/Sources/PostgresNIO/New/PostgresRowSequence.swift b/Sources/PostgresNIO/New/PostgresRowSequence.swift index 160cea02..0a7765a5 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence.swift @@ -6,7 +6,7 @@ import NIOConcurrencyHelpers /// /// - Note: This is a struct to allow us to move to a move only type easily once they become available. struct PostgresRowSequence: AsyncSequence { - typealias Element = PSQLRow + typealias Element = PostgresRow typealias AsyncIterator = Iterator final class _Internal { @@ -40,7 +40,7 @@ struct PostgresRowSequence: AsyncSequence { extension PostgresRowSequence { struct Iterator: AsyncIteratorProtocol { - typealias Element = PSQLRow + typealias Element = PostgresRow let _internal: _Internal @@ -48,7 +48,7 @@ extension PostgresRowSequence { self._internal = _Internal(consumer: consumer) } - mutating func next() async throws -> PSQLRow? { + mutating func next() async throws -> PostgresRow? { try await self._internal.next() } @@ -63,7 +63,7 @@ extension PostgresRowSequence { self.consumer.iteratorDeinitialized() } - func next() async throws -> PSQLRow? { + func next() async throws -> PostgresRow? { try await self.consumer.next() } } @@ -112,7 +112,7 @@ final class AsyncStreamConsumer { switch receiveAction { case .succeed(let continuation, let data, signalDemandTo: let source): - let row = PSQLRow( + let row = PostgresRow( data: data, lookupTable: self.lookupTable, columns: self.columns @@ -176,7 +176,7 @@ final class AsyncStreamConsumer { } } - func next() async throws -> PSQLRow? { + func next() async throws -> PostgresRow? { self.lock.lock() switch self.state.next() { case .returnNil: @@ -186,7 +186,7 @@ final class AsyncStreamConsumer { case .returnRow(let data, signalDemandTo: let source): self.lock.unlock() source?.demand() - return PSQLRow( + return PostgresRow( data: data, lookupTable: self.lookupTable, columns: self.columns @@ -217,7 +217,7 @@ extension AsyncStreamConsumer { enum UpstreamState { enum DemandState { case canAskForMore - case waitingForMore(CheckedContinuation?) + case waitingForMore(CheckedContinuation?) } case initialized @@ -397,7 +397,7 @@ extension AsyncStreamConsumer { case none } - mutating func next(for continuation: CheckedContinuation) -> NextSlowPathAction { + mutating func next(for continuation: CheckedContinuation) -> NextSlowPathAction { switch self.upstreamState { case .initialized: preconditionFailure() @@ -424,7 +424,7 @@ extension AsyncStreamConsumer { } enum ReceiveAction { - case succeed(CheckedContinuation, DataRow, signalDemandTo: PSQLRowStream?) + case succeed(CheckedContinuation, DataRow, signalDemandTo: PSQLRowStream?) case none } @@ -464,8 +464,8 @@ extension AsyncStreamConsumer { } enum CompletionResult { - case succeed(CheckedContinuation) - case fail(CheckedContinuation, Error) + case succeed(CheckedContinuation) + case fail(CheckedContinuation, Error) case none } @@ -533,8 +533,8 @@ extension AsyncStreamConsumer { } extension PostgresRowSequence { - func collect() async throws -> [PSQLRow] { - var result = [PSQLRow]() + func collect() async throws -> [PostgresRow] { + var result = [PostgresRow]() for try await row in self { result.append(row) } diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 61bdb136..16d720f7 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -11,17 +11,17 @@ final class IntegrationTests: XCTestCase { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - + var conn: PSQLConnection? XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) XCTAssertNoThrow(try conn?.close().wait()) } - + func testAuthenticationFailure() throws { // If the postgres server trusts every connection, it is really hard to create an // authentication failure. try XCTSkipIf(env("POSTGRES_HOST_AUTH_METHOD") == "trust") - + let config = PSQLConnection.Configuration( host: env("POSTGRES_HOSTNAME") ?? "localhost", port: 5432, @@ -29,115 +29,115 @@ final class IntegrationTests: XCTestCase { database: env("POSTGRES_DB") ?? "test_database", password: "wrong_password", tlsConfiguration: nil) - + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } - + var logger = Logger.psqlTest logger.logLevel = .info - + var connection: PSQLConnection? XCTAssertThrowsError(connection = try PSQLConnection.connect(configuration: config, logger: logger, on: eventLoopGroup.next()).wait()) { XCTAssertTrue($0 is PSQLError) } - + // In case of a test failure the created connection must be closed. XCTAssertNoThrow(try connection?.close().wait()) } - + func testQueryVersion() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - + var conn: PSQLConnection? XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - + var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query("SELECT version()", logger: .psqlTest).wait()) - var rows: [PSQLRow]? + var rows: [PostgresRow]? XCTAssertNoThrow(rows = try XCTUnwrap(stream).all().wait()) var version: String? - XCTAssertNoThrow(version = try rows?.first?.decode(column: 0, as: String.self)) + XCTAssertNoThrow(version = try rows?.first?.decode(String.self, context: .default)) XCTAssertEqual(version?.contains("PostgreSQL"), true) } - + func testQuery10kItems() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - + var conn: PSQLConnection? XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - + var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query("SELECT generate_series(1, 10000);", logger: .psqlTest).wait()) - + var received: Int64 = 0 - + XCTAssertNoThrow(try stream?.onRow { row in func workaround() { var number: Int64? - XCTAssertNoThrow(number = try row.decode(column: 0, as: Int64.self)) + XCTAssertNoThrow(number = try row.decode(Int64.self, context: .default)) received += 1 XCTAssertEqual(number, received) } - + workaround() }.wait()) - + XCTAssertEqual(received, 10000) } - + func test1kRoundTrips() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - + var conn: PSQLConnection? XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - + for _ in 0..<1_000 { var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query("SELECT version()", logger: .psqlTest).wait()) - var rows: [PSQLRow]? + var rows: [PostgresRow]? XCTAssertNoThrow(rows = try XCTUnwrap(stream).all().wait()) var version: String? - XCTAssertNoThrow(version = try rows?.first?.decode(column: 0, as: String.self)) + XCTAssertNoThrow(version = try rows?.first?.decode(String.self, context: .default)) XCTAssertEqual(version?.contains("PostgreSQL"), true) } } - + func testQuerySelectParameter() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - + var conn: PSQLConnection? XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - + var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query("SELECT \("hello")::TEXT as foo", logger: .psqlTest).wait()) - var rows: [PSQLRow]? + var rows: [PostgresRow]? XCTAssertNoThrow(rows = try XCTUnwrap(stream).all().wait()) var foo: String? - XCTAssertNoThrow(foo = try rows?.first?.decode(column: 0, as: String.self)) + XCTAssertNoThrow(foo = try rows?.first?.decode(String.self, context: .default)) XCTAssertEqual(foo, "hello") } - + func testDecodeIntegers() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - + var conn: PSQLConnection? XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - + var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query(""" SELECT @@ -151,88 +151,89 @@ final class IntegrationTests: XCTestCase { -9223372036854775807::BIGINT as bigint_min, 9223372036854775807::BIGINT as bigint_max """, logger: .psqlTest).wait()) - - var rows: [PSQLRow]? + + var rows: [PostgresRow]? XCTAssertNoThrow(rows = try stream?.all().wait()) XCTAssertEqual(rows?.count, 1) let row = rows?.first - - XCTAssertEqual(try row?.decode(column: "smallint", as: Int16.self), 1) - XCTAssertEqual(try row?.decode(column: "smallint_min", as: Int16.self), -32_767) - XCTAssertEqual(try row?.decode(column: "smallint_max", as: Int16.self), 32_767) - XCTAssertEqual(try row?.decode(column: "int", as: Int32.self), 1) - XCTAssertEqual(try row?.decode(column: "int_min", as: Int32.self), -2_147_483_647) - XCTAssertEqual(try row?.decode(column: "int_max", as: Int32.self), 2_147_483_647) - XCTAssertEqual(try row?.decode(column: "bigint", as: Int64.self), 1) - XCTAssertEqual(try row?.decode(column: "bigint_min", as: Int64.self), -9_223_372_036_854_775_807) - XCTAssertEqual(try row?.decode(column: "bigint_max", as: Int64.self), 9_223_372_036_854_775_807) + + var cells: (Int16, Int16, Int16, Int32, Int32, Int32, Int64, Int64, Int64)? + XCTAssertNoThrow(cells = try row?.decode((Int16, Int16, Int16, Int32, Int32, Int32, Int64, Int64, Int64).self, context: .default)) + + XCTAssertEqual(cells?.0, 1) + XCTAssertEqual(cells?.1, -32_767) + XCTAssertEqual(cells?.2, 32_767) + XCTAssertEqual(cells?.3, 1) + XCTAssertEqual(cells?.4, -2_147_483_647) + XCTAssertEqual(cells?.5, 2_147_483_647) + XCTAssertEqual(cells?.6, 1) + XCTAssertEqual(cells?.7, -9_223_372_036_854_775_807) + XCTAssertEqual(cells?.8, 9_223_372_036_854_775_807) } - + func testEncodeAndDecodeIntArray() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - + var conn: PSQLConnection? XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - + var stream: PSQLRowStream? let array: [Int64] = [1, 2, 3] XCTAssertNoThrow(stream = try conn?.query("SELECT \(array)::int8[] as array", logger: .psqlTest).wait()) - - var rows: [PSQLRow]? + var rows: [PostgresRow]? XCTAssertNoThrow(rows = try stream?.all().wait()) XCTAssertEqual(rows?.count, 1) - XCTAssertEqual(try rows?.first?.decode(column: "array", as: [Int64].self), array) + XCTAssertEqual(try rows?.first?.decode([Int64].self, context: .default), array) } - + func testDecodeEmptyIntegerArray() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - + var conn: PSQLConnection? XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - + var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query("SELECT '{}'::int[] as array", logger: .psqlTest).wait()) - - var rows: [PSQLRow]? + + var rows: [PostgresRow]? XCTAssertNoThrow(rows = try stream?.all().wait()) XCTAssertEqual(rows?.count, 1) - XCTAssertEqual(try rows?.first?.decode(column: "array", as: [Int64].self), []) + XCTAssertEqual(try rows?.first?.decode([Int64].self, context: .default), []) } - + func testDoubleArraySerialization() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - + var conn: PSQLConnection? XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - + var stream: PSQLRowStream? let doubles: [Double] = [3.14, 42] XCTAssertNoThrow(stream = try conn?.query("SELECT \(doubles)::double precision[] as doubles", logger: .psqlTest).wait()) - - var rows: [PSQLRow]? + var rows: [PostgresRow]? XCTAssertNoThrow(rows = try stream?.all().wait()) XCTAssertEqual(rows?.count, 1) - XCTAssertEqual(try rows?.first?.decode(column: "doubles", as: [Double].self), doubles) + XCTAssertEqual(try rows?.first?.decode([Double].self, context: .default), doubles) } - + func testDecodeDates() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - + var conn: PSQLConnection? XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - + var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query(""" SELECT @@ -240,103 +241,105 @@ final class IntegrationTests: XCTestCase { '2016-01-18 01:02:03 +0042'::TIMESTAMP as timestamp, '2016-01-18 01:02:03 +0042'::TIMESTAMPTZ as timestamptz """, logger: .psqlTest).wait()) - - var rows: [PSQLRow]? + + var rows: [PostgresRow]? XCTAssertNoThrow(rows = try stream?.all().wait()) XCTAssertEqual(rows?.count, 1) - let row = rows?.first - - XCTAssertEqual(try row?.decode(column: "date", as: Date.self).description, "2016-01-18 00:00:00 +0000") - XCTAssertEqual(try row?.decode(column: "timestamp", as: Date.self).description, "2016-01-18 01:02:03 +0000") - XCTAssertEqual(try row?.decode(column: "timestamptz", as: Date.self).description, "2016-01-18 00:20:03 +0000") + + var cells: (Date, Date, Date)? + XCTAssertNoThrow(cells = try rows?.first?.decode((Date, Date, Date).self, context: .default)) + + XCTAssertEqual(cells?.0.description, "2016-01-18 00:00:00 +0000") + XCTAssertEqual(cells?.1.description, "2016-01-18 01:02:03 +0000") + XCTAssertEqual(cells?.2.description, "2016-01-18 00:20:03 +0000") } - + func testDecodeDecimals() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - + var conn: PSQLConnection? XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - + var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query(""" SELECT \(Decimal(string: "123456.789123")!)::numeric as numeric, \(Decimal(string: "-123456.789123")!)::numeric as numeric_negative """, logger: .psqlTest).wait()) - - var rows: [PSQLRow]? + var rows: [PostgresRow]? XCTAssertNoThrow(rows = try stream?.all().wait()) XCTAssertEqual(rows?.count, 1) - let row = rows?.first - - XCTAssertEqual(try row?.decode(column: "numeric", as: Decimal.self), Decimal(string: "123456.789123")!) - XCTAssertEqual(try row?.decode(column: "numeric_negative", as: Decimal.self), Decimal(string: "-123456.789123")!) + + var cells: (Decimal, Decimal)? + XCTAssertNoThrow(cells = try rows?.first?.decode((Decimal, Decimal).self, context: .default)) + + XCTAssertEqual(cells?.0, Decimal(string: "123456.789123")) + XCTAssertEqual(cells?.1, Decimal(string: "-123456.789123")) } - + func testDecodeUUID() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - + var conn: PSQLConnection? XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - + var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query(""" SELECT '2c68f645-9ca6-468b-b193-ee97f241c2f8'::UUID as uuid """, logger: .psqlTest).wait()) - - var rows: [PSQLRow]? + + var rows: [PostgresRow]? XCTAssertNoThrow(rows = try stream?.all().wait()) XCTAssertEqual(rows?.count, 1) - - XCTAssertEqual(try rows?.first?.decode(column: "uuid", as: UUID.self), UUID(uuidString: "2c68f645-9ca6-468b-b193-ee97f241c2f8")) + XCTAssertEqual(try rows?.first?.decode(UUID.self, context: .default), UUID(uuidString: "2c68f645-9ca6-468b-b193-ee97f241c2f8")) } - + func testRoundTripJSONB() { struct Object: Codable, PostgresCodable { let foo: Int let bar: Int } - + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - + var conn: PSQLConnection? XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - + do { var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query(""" select \(Object(foo: 1, bar: 2))::jsonb as jsonb """, logger: .psqlTest).wait()) - var rows: [PSQLRow]? + var rows: [PostgresRow]? XCTAssertNoThrow(rows = try stream?.all().wait()) XCTAssertEqual(rows?.count, 1) var result: Object? - XCTAssertNoThrow(result = try rows?.first?.decode(column: "jsonb", as: Object.self)) + XCTAssertNoThrow(result = try rows?.first?.decode(Object.self, context: .default)) XCTAssertEqual(result?.foo, 1) XCTAssertEqual(result?.bar, 2) } - + do { var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query(""" select \(Object(foo: 1, bar: 2))::json as json """, logger: .psqlTest).wait()) - var rows: [PSQLRow]? + var rows: [PostgresRow]? XCTAssertNoThrow(rows = try stream?.all().wait()) XCTAssertEqual(rows?.count, 1) var result: Object? - XCTAssertNoThrow(result = try rows?.first?.decode(column: "json", as: Object.self)) + XCTAssertNoThrow(result = try rows?.first?.decode(Object.self, context: .default)) XCTAssertEqual(result?.foo, 1) XCTAssertEqual(result?.bar, 2) } @@ -345,7 +348,7 @@ final class IntegrationTests: XCTestCase { extension PSQLConnection { - + static func test(on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { var logger = Logger(label: "psql.connection.test") logger.logLevel = logLevel @@ -356,8 +359,11 @@ extension PSQLConnection { database: env("POSTGRES_DB") ?? "test_database", password: env("POSTGRES_PASSWORD") ?? "test_password", tlsConfiguration: nil) - + return PSQLConnection.connect(configuration: config, logger: logger, on: eventLoop) } - +} + +extension PostgresDecodingContext where JSONDecoder == Foundation.JSONDecoder { + static let `default`: Self = PostgresDecodingContext(jsonDecoder: JSONDecoder()) } diff --git a/Tests/IntegrationTests/PerformanceTests.swift b/Tests/IntegrationTests/PerformanceTests.swift index 7e74a595..59a2392a 100644 --- a/Tests/IntegrationTests/PerformanceTests.swift +++ b/Tests/IntegrationTests/PerformanceTests.swift @@ -2,7 +2,7 @@ import XCTest import Logging import NIOCore import NIOPosix -import PostgresNIO +@testable import PostgresNIO import NIOTestUtils final class PerformanceTests: XCTestCase { @@ -38,7 +38,7 @@ final class PerformanceTests: XCTestCase { do { for _ in 0..<5 { try conn.query("SELECT * FROM generate_series(1, 10000) num") { row in - _ = row.column("num")?.int + _ = try row.decode(Int.self, context: .default) }.wait() } } catch { @@ -65,8 +65,8 @@ final class PerformanceTests: XCTestCase { measure { do { try conn.query("SELECT * FROM \"measureSelectPerformance\"") { row in - _ = row.column("int")?.int - }.wait() + _ = try row.decode(Int.self, context: .default) + }.wait() } catch { XCTFail("\(error)") } @@ -101,12 +101,13 @@ final class PerformanceTests: XCTestCase { measure { do { - try conn.query("SELECT * FROM \"measureSelectPerformance\"") { row in - _ = row.column("id")?.int - _ = row.column("string")?.string - _ = row.column("int")?.int - _ = row.column("date")?.date - _ = row.column("uuid")?.uuid + try conn.query("SELECT * FROM \"measureSelectPerformance\"") { + let row = $0.makeRandomAccess() + _ = row[data: "id"].int + _ = row[data: "string"].string + _ = row[data: "int"].int + _ = row[data: "date"].date + _ = row[data: "uuid"].uuid }.wait() } catch { XCTFail("\(error)") @@ -174,28 +175,29 @@ final class PerformanceTests: XCTestCase { measure { do { - try conn.query("SELECT * FROM \"measureSelectPerformance\"") { row in - _ = row.column("id")?.int - _ = row.column("string1")?.string - _ = row.column("string2")?.string - _ = row.column("string3")?.string - _ = row.column("string4")?.string - _ = row.column("string5")?.string - _ = row.column("int1")?.int - _ = row.column("int2")?.int - _ = row.column("int3")?.int - _ = row.column("int4")?.int - _ = row.column("int5")?.int - _ = row.column("date1")?.date - _ = row.column("date2")?.date - _ = row.column("date3")?.date - _ = row.column("date4")?.date - _ = row.column("date5")?.date - _ = row.column("uuid1")?.uuid - _ = row.column("uuid2")?.uuid - _ = row.column("uuid3")?.uuid - _ = row.column("uuid4")?.uuid - _ = row.column("uuid5")?.uuid + try conn.query("SELECT * FROM \"measureSelectPerformance\"") { + let row = $0.makeRandomAccess() + _ = row[data: "id"].int + _ = row[data: "string1"].string + _ = row[data: "string2"].string + _ = row[data: "string3"].string + _ = row[data: "string4"].string + _ = row[data: "string5"].string + _ = row[data: "int1"].int + _ = row[data: "int2"].int + _ = row[data: "int3"].int + _ = row[data: "int4"].int + _ = row[data: "int5"].int + _ = row[data: "date1"].date + _ = row[data: "date2"].date + _ = row[data: "date3"].date + _ = row[data: "date4"].date + _ = row[data: "date5"].date + _ = row[data: "uuid1"].uuid + _ = row[data: "uuid2"].uuid + _ = row[data: "uuid3"].uuid + _ = row[data: "uuid4"].uuid + _ = row[data: "uuid5"].uuid }.wait() } catch { XCTFail("\(error)") @@ -219,10 +221,11 @@ final class PerformanceTests: XCTestCase { measure { do { - try conn.query("SELECT * FROM \"measureSelectPerformance\"") { row in - _ = row.column("id")?.int + try conn.query("SELECT * FROM \"measureSelectPerformance\"") { + let row = $0.makeRandomAccess() + _ = row[data: "id"].int for fieldName in fieldNames { - _ = row.column(fieldName)?.int + _ = row[data: fieldName].int } }.wait() } catch { @@ -247,10 +250,11 @@ final class PerformanceTests: XCTestCase { measure { do { - try conn.query("SELECT * FROM \"measureSelectPerformance\"") { row in - _ = row.column("id")?.int + try conn.query("SELECT * FROM \"measureSelectPerformance\"") { + let row = $0.makeRandomAccess() + _ = row[data: "id"].int for fieldName in fieldNames { - _ = row.column(fieldName)?.int + _ = row[data: fieldName].int } }.wait() } catch { diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index ff1fb804..7be9bab7 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -38,7 +38,7 @@ final class PostgresNIOTests: XCTestCase { var rows: [PostgresRow]? XCTAssertNoThrow(rows = try conn?.simpleQuery("SELECT version()").wait()) XCTAssertEqual(rows?.count, 1) - XCTAssertEqual(rows?.first?.column("version")?.string?.contains("PostgreSQL"), true) + XCTAssertEqual(try rows?.first?.decode(String.self, context: .default).contains("PostgreSQL"), true) } func testQueryVersion() { @@ -48,7 +48,7 @@ final class PostgresNIOTests: XCTestCase { var rows: PostgresQueryResult? XCTAssertNoThrow(rows = try conn?.query("SELECT version()", .init()).wait()) XCTAssertEqual(rows?.count, 1) - XCTAssertEqual(rows?.first?.column("version")?.string?.contains("PostgreSQL"), true) + XCTAssertEqual(try rows?.first?.decode(String.self, context: .default).contains("PostgreSQL"), true) } func testQuerySelectParameter() { @@ -58,7 +58,7 @@ final class PostgresNIOTests: XCTestCase { var rows: PostgresQueryResult? XCTAssertNoThrow(rows = try conn?.query("SELECT $1::TEXT as foo", ["hello"]).wait()) XCTAssertEqual(rows?.count, 1) - XCTAssertEqual(rows?.first?.column("foo")?.string, "hello") + XCTAssertEqual(try rows?.first?.decode(String.self, context: .default), "hello") } func testSQLError() throws { @@ -240,11 +240,11 @@ final class PostgresNIOTests: XCTestCase { // "typoutput": "float8out" // ] XCTAssertEqual(results?.count, 1) - let row = results?.first - XCTAssertEqual(row?.column("typname")?.string, "float8") - XCTAssertEqual(row?.column("typnamespace")?.int, 11) - XCTAssertEqual(row?.column("typowner")?.int, 10) - XCTAssertEqual(row?.column("typlen")?.int, 8) + let row = results?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "typname"].string, "float8") + XCTAssertEqual(row?[data: "typnamespace"].int, 11) + XCTAssertEqual(row?[data: "typowner"].int, 10) + XCTAssertEqual(row?[data: "typlen"].int, 8) } func testIntegers() { @@ -277,16 +277,16 @@ final class PostgresNIOTests: XCTestCase { """).wait()) XCTAssertEqual(results?.count, 1) - let row = results?.first - XCTAssertEqual(row?.column("smallint")?.int16, 1) - XCTAssertEqual(row?.column("smallint_min")?.int16, -32_767) - XCTAssertEqual(row?.column("smallint_max")?.int16, 32_767) - XCTAssertEqual(row?.column("int")?.int32, 1) - XCTAssertEqual(row?.column("int_min")?.int32, -2_147_483_647) - XCTAssertEqual(row?.column("int_max")?.int32, 2_147_483_647) - XCTAssertEqual(row?.column("bigint")?.int64, 1) - XCTAssertEqual(row?.column("bigint_min")?.int64, -9_223_372_036_854_775_807) - XCTAssertEqual(row?.column("bigint_max")?.int64, 9_223_372_036_854_775_807) + let row = results?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "smallint"].int16, 1) + XCTAssertEqual(row?[data: "smallint_min"].int16, -32_767) + XCTAssertEqual(row?[data: "smallint_max"].int16, 32_767) + XCTAssertEqual(row?[data: "int"].int32, 1) + XCTAssertEqual(row?[data: "int_min"].int32, -2_147_483_647) + XCTAssertEqual(row?[data: "int_max"].int32, 2_147_483_647) + XCTAssertEqual(row?[data: "bigint"].int64, 1) + XCTAssertEqual(row?[data: "bigint_min"].int64, -9_223_372_036_854_775_807) + XCTAssertEqual(row?[data: "bigint_max"].int64, 9_223_372_036_854_775_807) } func testPi() { @@ -311,13 +311,13 @@ final class PostgresNIOTests: XCTestCase { pi()::FLOAT4 as float """).wait()) XCTAssertEqual(results?.count, 1) - let row = results?.first - XCTAssertEqual(row?.column("text")?.string?.hasPrefix("3.14159265"), true) - XCTAssertEqual(row?.column("numeric_string")?.string?.hasPrefix("3.14159265"), true) - XCTAssertTrue(row?.column("numeric_decimal")?.decimal?.isLess(than: 3.14159265358980) ?? false) - XCTAssertFalse(row?.column("numeric_decimal")?.decimal?.isLess(than: 3.14159265358978) ?? true) - XCTAssertTrue(row?.column("double")?.double?.description.hasPrefix("3.141592") ?? false) - XCTAssertTrue(row?.column("float")?.float?.description.hasPrefix("3.141592") ?? false) + let row = results?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "text"].string?.hasPrefix("3.14159265"), true) + XCTAssertEqual(row?[data: "numeric_string"].string?.hasPrefix("3.14159265"), true) + XCTAssertTrue(row?[data: "numeric_decimal"].decimal?.isLess(than: 3.14159265358980) ?? false) + XCTAssertFalse(row?[data: "numeric_decimal"].decimal?.isLess(than: 3.14159265358978) ?? true) + XCTAssertTrue(row?[data: "double"].double?.description.hasPrefix("3.141592") ?? false) + XCTAssertTrue(row?[data: "float"].float?.description.hasPrefix("3.141592") ?? false) } func testUUID() { @@ -335,8 +335,9 @@ final class PostgresNIOTests: XCTestCase { '123e4567-e89b-12d3-a456-426655440000'::UUID as string """).wait()) XCTAssertEqual(results?.count, 1) - XCTAssertEqual(results?.first?.column("id")?.uuid, UUID(uuidString: "123E4567-E89B-12D3-A456-426655440000")) - XCTAssertEqual(UUID(uuidString: results?.first?.column("id")?.string ?? ""), UUID(uuidString: "123E4567-E89B-12D3-A456-426655440000")) + let row = results?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "id"].uuid, UUID(uuidString: "123E4567-E89B-12D3-A456-426655440000")) + XCTAssertEqual(UUID(uuidString: row?[data: "id"].string ?? ""), UUID(uuidString: "123E4567-E89B-12D3-A456-426655440000")) } func testDates() { @@ -356,10 +357,10 @@ final class PostgresNIOTests: XCTestCase { '2016-01-18 01:02:03 +0042'::TIMESTAMPTZ as timestamptz """).wait()) XCTAssertEqual(results?.count, 1) - let row = results?.first - XCTAssertEqual(row?.column("date")?.date?.description, "2016-01-18 00:00:00 +0000") - XCTAssertEqual(row?.column("timestamp")?.date?.description, "2016-01-18 01:02:03 +0000") - XCTAssertEqual(row?.column("timestamptz")?.date?.description, "2016-01-18 00:20:03 +0000") + let row = results?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "date"].date?.description, "2016-01-18 00:00:00 +0000") + XCTAssertEqual(row?[data: "timestamp"].date?.description, "2016-01-18 01:02:03 +0000") + XCTAssertEqual(row?[data: "timestamptz"].date?.description, "2016-01-18 00:20:03 +0000") } /// https://github.com/vapor/nio-postgres/issues/20 @@ -381,7 +382,8 @@ final class PostgresNIOTests: XCTestCase { defer { XCTAssertNoThrow( try conn?.close().wait() ) } var results: PostgresQueryResult? XCTAssertNoThrow(results = try conn?.query("select avg(length('foo')) as average_length").wait()) - XCTAssertEqual(results?.first?.column("average_length")?.double, 3.0) + let row = results?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: 0].double, 3.0) } func testNumericParsing() { @@ -406,18 +408,18 @@ final class PostgresNIOTests: XCTestCase { '0.5'::numeric as m """).wait()) XCTAssertEqual(rows?.count, 1) - let row = rows?.first - XCTAssertEqual(row?.column("a")?.string, "1234.5678") - XCTAssertEqual(row?.column("b")?.string, "-123.456") - XCTAssertEqual(row?.column("c")?.string, "123456.789123") - XCTAssertEqual(row?.column("d")?.string, "3.14159265358979") - XCTAssertEqual(row?.column("e")?.string, "10000") - XCTAssertEqual(row?.column("f")?.string, "0.00001") - XCTAssertEqual(row?.column("g")?.string, "100000000") - XCTAssertEqual(row?.column("h")?.string, "0.000000001") - XCTAssertEqual(row?.column("k")?.string, "123000000000") - XCTAssertEqual(row?.column("l")?.string, "0.000000000123") - XCTAssertEqual(row?.column("m")?.string, "0.5") + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "a"].string, "1234.5678") + XCTAssertEqual(row?[data: "b"].string, "-123.456") + XCTAssertEqual(row?[data: "c"].string, "123456.789123") + XCTAssertEqual(row?[data: "d"].string, "3.14159265358979") + XCTAssertEqual(row?[data: "e"].string, "10000") + XCTAssertEqual(row?[data: "f"].string, "0.00001") + XCTAssertEqual(row?[data: "g"].string, "100000000") + XCTAssertEqual(row?[data: "h"].string, "0.000000001") + XCTAssertEqual(row?[data: "k"].string, "123000000000") + XCTAssertEqual(row?[data: "l"].string, "0.000000000123") + XCTAssertEqual(row?[data: "m"].string, "0.5") } func testSingleNumericParsing() { @@ -431,7 +433,8 @@ final class PostgresNIOTests: XCTestCase { select '\(numeric)'::numeric as n """).wait()) - XCTAssertEqual(rows?.first?.column("n")?.string, numeric) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "n"].string, numeric) } func testRandomlyGeneratedNumericParsing() throws { @@ -452,7 +455,8 @@ final class PostgresNIOTests: XCTestCase { select '\(number)'::numeric as n """).wait()) - XCTAssertEqual(rows?.first?.column("n")?.string, number) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "n"].string, number) } } @@ -474,9 +478,10 @@ final class PostgresNIOTests: XCTestCase { .init(numeric: b), .init(numeric: c) ]).wait()) - XCTAssertEqual(rows?.first?.column("a")?.decimal, Decimal(string: "123456.789123")!) - XCTAssertEqual(rows?.first?.column("b")?.decimal, Decimal(string: "-123456.789123")!) - XCTAssertEqual(rows?.first?.column("c")?.decimal, Decimal(string: "3.14159265358979")!) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "a"].decimal, Decimal(string: "123456.789123")!) + XCTAssertEqual(row?[data: "b"].decimal, Decimal(string: "-123456.789123")!) + XCTAssertEqual(row?[data: "c"].decimal, Decimal(string: "3.14159265358979")!) } func testDecimalStringSerialization() { @@ -500,7 +505,8 @@ final class PostgresNIOTests: XCTestCase { "balance" FROM table1 """).wait()) - XCTAssertEqual(rows?.first?.column("balance")?.decimal, Decimal(string: "123456.789123")!) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "balance"].decimal, Decimal(string: "123456.789123")!) } func testMoney() { @@ -516,11 +522,12 @@ final class PostgresNIOTests: XCTestCase { '3.14'::money as d, '12345678.90'::money as e """).wait()) - XCTAssertEqual(rows?.first?.column("a")?.string, "0.00") - XCTAssertEqual(rows?.first?.column("b")?.string, "0.05") - XCTAssertEqual(rows?.first?.column("c")?.string, "0.23") - XCTAssertEqual(rows?.first?.column("d")?.string, "3.14") - XCTAssertEqual(rows?.first?.column("e")?.string, "12345678.90") + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "a"].string, "0.00") + XCTAssertEqual(row?[data: "b"].string, "0.05") + XCTAssertEqual(row?[data: "c"].string, "0.23") + XCTAssertEqual(row?[data: "d"].string, "3.14") + XCTAssertEqual(row?[data: "e"].string, "12345678.90") } func testIntegerArrayParse() { @@ -532,7 +539,8 @@ final class PostgresNIOTests: XCTestCase { select '{1,2,3}'::int[] as array """).wait()) - XCTAssertEqual(rows?.first?.column("array")?.array(of: Int.self), [1, 2, 3]) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "array"].array(of: Int.self), [1, 2, 3]) } func testEmptyIntegerArrayParse() { @@ -544,7 +552,8 @@ final class PostgresNIOTests: XCTestCase { select '{}'::int[] as array """).wait()) - XCTAssertEqual(rows?.first?.column("array")?.array(of: Int.self), []) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "array"].array(of: Int.self), []) } func testNullIntegerArrayParse() { @@ -556,7 +565,8 @@ final class PostgresNIOTests: XCTestCase { select null::int[] as array """).wait()) - XCTAssertEqual(rows?.first?.column("array")?.array(of: Int.self), nil) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "array"].array(of: Int.self), nil) } func testIntegerArraySerialize() { @@ -570,7 +580,8 @@ final class PostgresNIOTests: XCTestCase { """, [ PostgresData(array: [1, 2, 3]) ]).wait()) - XCTAssertEqual(rows?.first?.column("array")?.array(of: Int.self), [1, 2, 3]) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "array"].array(of: Int.self), [1, 2, 3]) } func testEmptyIntegerArraySerialize() { @@ -584,7 +595,8 @@ final class PostgresNIOTests: XCTestCase { """, [ PostgresData(array: [] as [Int]) ]).wait()) - XCTAssertEqual(rows?.first?.column("array")?.array(of: Int.self), []) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "array"].array(of: Int.self), []) } // https://github.com/vapor/postgres-nio/issues/143 @@ -610,7 +622,8 @@ final class PostgresNIOTests: XCTestCase { var rows: [PostgresRow]? XCTAssertNoThrow(rows = try conn?.simpleQuery(#"SELECT * FROM "non_null_empty_strings""#).wait()) XCTAssertEqual(rows?.count, 1) - XCTAssertEqual(rows?.first?.column("nonNullString")?.string, "") // <--- this fails + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "nonNullString"].string, "") // <--- this fails } @@ -621,22 +634,26 @@ final class PostgresNIOTests: XCTestCase { do { var rows: PostgresQueryResult? XCTAssertNoThrow(rows = try conn?.query("select $1::bool as bool", [true]).wait()) - XCTAssertEqual(rows?.first?.column("bool")?.bool, true) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "bool"].bool, true) } do { var rows: PostgresQueryResult? XCTAssertNoThrow(rows = try conn?.query("select $1::bool as bool", [false]).wait()) - XCTAssertEqual(rows?.first?.column("bool")?.bool, false) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "bool"].bool, false) } do { var rows: [PostgresRow]? XCTAssertNoThrow(rows = try conn?.simpleQuery("select true::bool as bool").wait()) - XCTAssertEqual(rows?.first?.column("bool")?.bool, true) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "bool"].bool, true) } do { var rows: [PostgresRow]? XCTAssertNoThrow(rows = try conn?.simpleQuery("select false::bool as bool").wait()) - XCTAssertEqual(rows?.first?.column("bool")?.bool, false) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "bool"].bool, false) } } @@ -648,11 +665,12 @@ final class PostgresNIOTests: XCTestCase { XCTAssertNoThrow(rows = try conn?.query("select $1::bytea as bytes", [ PostgresData(bytes: [1, 2, 3]) ]).wait()) - XCTAssertEqual(rows?.first?.column("bytes")?.bytes, [1, 2, 3]) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "bytes"].bytes, [1, 2, 3]) } func testJSONBSerialize() { - struct Object: Codable { + struct Object: Codable, PostgresCodable { let foo: Int let bar: Int } @@ -667,7 +685,7 @@ final class PostgresNIOTests: XCTestCase { XCTAssertNoThrow(rows = try conn?.query("select $1::jsonb as jsonb", [XCTUnwrap(postgresData)]).wait()) var object: Object? - XCTAssertNoThrow(object = try rows?.first?.column("jsonb")?.jsonb(as: Object.self)) + XCTAssertNoThrow(object = try rows?.first?.decode(Object.self, context: .default)) XCTAssertEqual(object?.foo, 1) XCTAssertEqual(object?.bar, 2) } @@ -677,7 +695,7 @@ final class PostgresNIOTests: XCTestCase { XCTAssertNoThrow(rows = try conn?.query("select jsonb_build_object('foo',1,'bar',2) as jsonb").wait()) var object: Object? - XCTAssertNoThrow(object = try rows?.first?.column("jsonb")?.jsonb(as: Object.self)) + XCTAssertNoThrow(object = try rows?.first?.decode(Object.self, context: .default)) XCTAssertEqual(object?.foo, 1) XCTAssertEqual(object?.bar, 2) } @@ -701,8 +719,8 @@ final class PostgresNIOTests: XCTestCase { var rows: [PostgresRow]? XCTAssertNoThrow(rows = try conn?.simpleQuery("SELECT version()").wait()) XCTAssertEqual(rows?.count, 1) - let version = rows?.first?.column("version")?.string - XCTAssertEqual(version?.contains("PostgreSQL"), true) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "version"].string?.contains("PostgreSQL"), true) } func testFailingTLSConnectionClosesConnection() { @@ -787,16 +805,16 @@ final class PostgresNIOTests: XCTestCase { * FROM table1 INNER JOIN table2 ON table1.table2_id = table2.id """).wait()) - let row = rows?.first - XCTAssertEqual(row?.column("t1_id")?.int, 12) - XCTAssertEqual(row?.column("table2_id")?.int, 34) - XCTAssertEqual(row?.column("t1_intValue")?.int, 56) - XCTAssertEqual(row?.column("t1_stringValue")?.string, "stringInTable1") - XCTAssertEqual(row?.column("t1_dateValue")?.date, dateInTable1) - XCTAssertEqual(row?.column("t2_id")?.int, 34) - XCTAssertEqual(row?.column("t2_intValue")?.int, 78) - XCTAssertEqual(row?.column("t2_stringValue")?.string, "stringInTable2") - XCTAssertEqual(row?.column("t2_dateValue")?.date, dateInTable2) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "t1_id"].int, 12) + XCTAssertEqual(row?[data: "table2_id"].int, 34) + XCTAssertEqual(row?[data: "t1_intValue"].int, 56) + XCTAssertEqual(row?[data: "t1_stringValue"].string, "stringInTable1") + XCTAssertEqual(row?[data: "t1_dateValue"].date, dateInTable1) + XCTAssertEqual(row?[data: "t2_id"].int, 34) + XCTAssertEqual(row?[data: "t2_intValue"].int, 78) + XCTAssertEqual(row?[data: "t2_stringValue"].string, "stringInTable2") + XCTAssertEqual(row?[data: "t2_dateValue"].date, dateInTable2) } func testStringArrays() { @@ -826,9 +844,10 @@ final class PostgresNIOTests: XCTestCase { PostgresData(array: ["en"]), PostgresData(array: ["USD", "DKK"]), ]).wait()) - XCTAssertEqual(rows?.first?.column("countries")?.array(of: String.self), ["US"]) - XCTAssertEqual(rows?.first?.column("languages")?.array(of: String.self), ["en"]) - XCTAssertEqual(rows?.first?.column("currencies")?.array(of: String.self), ["USD", "DKK"]) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "countries"].array(of: String.self), ["US"]) + XCTAssertEqual(row?[data: "languages"].array(of: String.self), ["en"]) + XCTAssertEqual(row?[data: "currencies"].array(of: String.self), ["USD", "DKK"]) } func testBindDate() { @@ -861,7 +880,8 @@ final class PostgresNIOTests: XCTestCase { defer { XCTAssertNoThrow( try conn?.close().wait() ) } var rows: PostgresQueryResult? XCTAssertNoThrow(rows = try conn?.query(query, [.init(string: "f")]).wait()) - XCTAssertEqual(rows?.first?.column("char")?.string, "f") + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "char"].string, "f") } func testBindCharUInt8() { @@ -874,7 +894,8 @@ final class PostgresNIOTests: XCTestCase { defer { XCTAssertNoThrow( try conn?.close().wait() ) } var rows: PostgresQueryResult? XCTAssertNoThrow(rows = try conn?.query(query, [.init(uint8: 42)]).wait()) - XCTAssertEqual(rows?.first?.column("char")?.string, "*") + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "char"].string, "*") } func testDoubleArraySerialization() { @@ -889,7 +910,8 @@ final class PostgresNIOTests: XCTestCase { """, [ .init(array: doubles) ]).wait()) - XCTAssertEqual(rows?.first?.column("doubles")?.array(of: Double.self), doubles) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "doubles"].array(of: Double.self), doubles) } // https://github.com/vapor/postgres-nio/issues/42 @@ -904,7 +926,8 @@ final class PostgresNIOTests: XCTestCase { """, [ .init(uint8: 5) ]).wait()) - XCTAssertEqual(rows?.first?.column("int")?.uint8, 5) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "int"].uint8, 5) } func testPreparedQuery() { @@ -917,7 +940,8 @@ final class PostgresNIOTests: XCTestCase { XCTAssertNoThrow(rows = try prepared?.execute().wait()) XCTAssertEqual(rows?.count, 1) - XCTAssertEqual(rows?.first?.column("one")?.int, 1) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "one"].int, 1) } func testPrepareQueryClosure() { @@ -932,10 +956,10 @@ final class PostgresNIOTests: XCTestCase { return EventLoopFuture.whenAllSucceed([a, b, c], on: self.eventLoop) }).wait()) XCTAssertEqual(queries?.count, 3) - var iterator = queries?.makeIterator() - XCTAssertEqual(iterator?.next()?.first?.column("foo")?.string, "a") - XCTAssertEqual(iterator?.next()?.first?.column("foo")?.string, "b") - XCTAssertEqual(iterator?.next()?.first?.column("foo")?.string, "c") + var resutIterator = queries?.makeIterator() + XCTAssertEqual(try resutIterator?.next()?.first?.decode(String.self, context: .default), "a") + XCTAssertEqual(try resutIterator?.next()?.first?.decode(String.self, context: .default), "b") + XCTAssertEqual(try resutIterator?.next()?.first?.decode(String.self, context: .default), "c") } // https://github.com/vapor/postgres-nio/issues/122 @@ -970,12 +994,13 @@ final class PostgresNIOTests: XCTestCase { '5'::char(2) as two """).wait()) - XCTAssertEqual(rows?.first?.column("one")?.uint8, 53) - XCTAssertEqual(rows?.first?.column("one")?.int16, 53) - XCTAssertEqual(rows?.first?.column("one")?.string, "5") - XCTAssertEqual(rows?.first?.column("two")?.uint8, nil) - XCTAssertEqual(rows?.first?.column("two")?.int16, nil) - XCTAssertEqual(rows?.first?.column("two")?.string, "5 ") + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "one"].uint8, 53) + XCTAssertEqual(row?[data: "one"].int16, 53) + XCTAssertEqual(row?[data: "one"].string, "5") + XCTAssertEqual(row?[data: "two"].uint8, nil) + XCTAssertEqual(row?[data: "two"].int16, nil) + XCTAssertEqual(row?[data: "two"].string, "5 ") } func testUserDefinedType() { @@ -990,7 +1015,8 @@ final class PostgresNIOTests: XCTestCase { } var res: PostgresQueryResult? XCTAssertNoThrow(res = try conn?.query("SELECT 'qux'::foo as foo").wait()) - XCTAssertEqual(res?.first?.column("foo")?.string, "qux") + let row = res?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "foo"].string, "qux") } func testNullBind() { @@ -1000,7 +1026,8 @@ final class PostgresNIOTests: XCTestCase { var res: PostgresQueryResult? XCTAssertNoThrow(res = try conn?.query("SELECT $1::text as foo", [String?.none.postgresData!]).wait()) - XCTAssertEqual(res?.first?.column("foo")?.string, nil) + let row = res?.first?.makeRandomAccess() + XCTAssertNil(row?[data: "foo"].string) } func testUpdateMetadata() { @@ -1048,7 +1075,8 @@ final class PostgresNIOTests: XCTestCase { var res: PostgresQueryResult? XCTAssertNoThrow(res = try conn?.query(#"SELECT '{"foo", "bar", "baz"}'::VARCHAR[] as foo"#).wait()) - XCTAssertEqual(res?.first?.column("foo")?.array(of: String.self), ["foo", "bar", "baz"]) + let row = res?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "foo"].array(of: String.self), ["foo", "bar", "baz"]) } // https://github.com/vapor/postgres-nio/issues/115 @@ -1079,37 +1107,38 @@ final class PostgresNIOTests: XCTestCase { '-9223372036854775808'::bigint as min64, '9223372036854775807'::bigint as max64 """).wait()) - XCTAssertEqual(rows?.first?.column("test8")?.uint8, 97) - XCTAssertEqual(rows?.first?.column("test8")?.int16, 97) - XCTAssertEqual(rows?.first?.column("test8")?.int32, 97) - XCTAssertEqual(rows?.first?.column("test8")?.int64, 97) - - XCTAssertEqual(rows?.first?.column("min16")?.uint8, nil) - XCTAssertEqual(rows?.first?.column("max16")?.uint8, nil) - XCTAssertEqual(rows?.first?.column("min16")?.int16, .min) - XCTAssertEqual(rows?.first?.column("max16")?.int16, .max) - XCTAssertEqual(rows?.first?.column("min16")?.int32, -32768) - XCTAssertEqual(rows?.first?.column("max16")?.int32, 32767) - XCTAssertEqual(rows?.first?.column("min16")?.int64, -32768) - XCTAssertEqual(rows?.first?.column("max16")?.int64, 32767) - - XCTAssertEqual(rows?.first?.column("min32")?.uint8, nil) - XCTAssertEqual(rows?.first?.column("max32")?.uint8, nil) - XCTAssertEqual(rows?.first?.column("min32")?.int16, nil) - XCTAssertEqual(rows?.first?.column("max32")?.int16, nil) - XCTAssertEqual(rows?.first?.column("min32")?.int32, .min) - XCTAssertEqual(rows?.first?.column("max32")?.int32, .max) - XCTAssertEqual(rows?.first?.column("min32")?.int64, -2147483648) - XCTAssertEqual(rows?.first?.column("max32")?.int64, 2147483647) - - XCTAssertEqual(rows?.first?.column("min64")?.uint8, nil) - XCTAssertEqual(rows?.first?.column("max64")?.uint8, nil) - XCTAssertEqual(rows?.first?.column("min64")?.int16, nil) - XCTAssertEqual(rows?.first?.column("max64")?.int16, nil) - XCTAssertEqual(rows?.first?.column("min64")?.int32, nil) - XCTAssertEqual(rows?.first?.column("max64")?.int32, nil) - XCTAssertEqual(rows?.first?.column("min64")?.int64, .min) - XCTAssertEqual(rows?.first?.column("max64")?.int64, .max) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "test8"].uint8, 97) + XCTAssertEqual(row?[data: "test8"].int16, 97) + XCTAssertEqual(row?[data: "test8"].int32, 97) + XCTAssertEqual(row?[data: "test8"].int64, 97) + + XCTAssertEqual(row?[data: "min16"].uint8, nil) + XCTAssertEqual(row?[data: "max16"].uint8, nil) + XCTAssertEqual(row?[data: "min16"].int16, .min) + XCTAssertEqual(row?[data: "max16"].int16, .max) + XCTAssertEqual(row?[data: "min16"].int32, -32768) + XCTAssertEqual(row?[data: "max16"].int32, 32767) + XCTAssertEqual(row?[data: "min16"].int64, -32768) + XCTAssertEqual(row?[data: "max16"].int64, 32767) + + XCTAssertEqual(row?[data: "min32"].uint8, nil) + XCTAssertEqual(row?[data: "max32"].uint8, nil) + XCTAssertEqual(row?[data: "min32"].int16, nil) + XCTAssertEqual(row?[data: "max32"].int16, nil) + XCTAssertEqual(row?[data: "min32"].int32, .min) + XCTAssertEqual(row?[data: "max32"].int32, .max) + XCTAssertEqual(row?[data: "min32"].int64, -2147483648) + XCTAssertEqual(row?[data: "max32"].int64, 2147483647) + + XCTAssertEqual(row?[data: "min64"].uint8, nil) + XCTAssertEqual(row?[data: "max64"].uint8, nil) + XCTAssertEqual(row?[data: "min64"].int16, nil) + XCTAssertEqual(row?[data: "max64"].int16, nil) + XCTAssertEqual(row?[data: "min64"].int32, nil) + XCTAssertEqual(row?[data: "max64"].int32, nil) + XCTAssertEqual(row?[data: "min64"].int64, .min) + XCTAssertEqual(row?[data: "max64"].int64, .max) } } diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift index dbf506fa..6a9dfbb5 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -81,7 +81,7 @@ class PSQLRowStreamTests: XCTestCase { let future = stream.all() XCTAssertEqual(dataSource.hitDemand, 0) // TODO: Is this right? - var rows: [PSQLRow]? + var rows: [PostgresRow]? XCTAssertNoThrow(rows = try future.wait()) XCTAssertEqual(rows?.count, 2) } @@ -131,7 +131,7 @@ class PSQLRowStreamTests: XCTestCase { stream.receive(completion: .success("SELECT 2")) - var rows: [PSQLRow]? + var rows: [PostgresRow]? XCTAssertNoThrow(rows = try future.wait()) XCTAssertEqual(rows?.count, 6) } @@ -170,7 +170,7 @@ class PSQLRowStreamTests: XCTestCase { // attach consumer var counter = 0 let future = stream.onRow { row in - XCTAssertEqual(try row.decode(column: 0, as: String.self), "\(counter)") + XCTAssertEqual(try row.decode(String.self, context: .forTests()), "\(counter)") counter += 1 } XCTAssertEqual(counter, 2) @@ -214,7 +214,7 @@ class PSQLRowStreamTests: XCTestCase { // attach consumer var counter = 0 let future = stream.onRow { row in - XCTAssertEqual(try row.decode(column: 0, as: String.self), "\(counter)") + XCTAssertEqual(try row.decode(String.self, context: .forTests()), "\(counter)") if counter == 1 { throw OnRowError(row: counter) } @@ -261,7 +261,7 @@ class PSQLRowStreamTests: XCTestCase { // attach consumer var counter = 0 let future = stream.onRow { row in - XCTAssertEqual(try row.decode(column: 0, as: String.self), "\(counter)") + XCTAssertEqual(try row.decode(String.self, context: .forTests()), "\(counter)") counter += 1 } XCTAssertEqual(counter, 2) diff --git a/Tests/PostgresNIOTests/New/PostgresCodableTests.swift b/Tests/PostgresNIOTests/New/PostgresCodableTests.swift index bf300c1f..0a3096e8 100644 --- a/Tests/PostgresNIOTests/New/PostgresCodableTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresCodableTests.swift @@ -4,7 +4,7 @@ import XCTest final class PostgresCodableTests: XCTestCase { func testDecodeAnOptionalFromARow() { - let row = PSQLRow( + let row = PostgresRow( data: .makeTestDataRow(nil, ByteBuffer(string: "Hello world!")), lookupTable: ["id": 0, "name": 1], columns: [ @@ -36,7 +36,7 @@ final class PostgresCodableTests: XCTestCase { } func testDecodeMissingValueError() { - let row = PSQLRow( + let row = PostgresRow( data: .makeTestDataRow(nil), lookupTable: ["name": 0], columns: [ diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index d42beb85..9d6c467d 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -59,7 +59,7 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 0 for try await row in rowSequence { - XCTAssertEqual(try row.decode(column: 0, as: Int.self), counter) + XCTAssertEqual(try row.decode(Int.self, context: .forTests()), counter) counter += 1 if counter == 64 { @@ -141,7 +141,7 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 0 for try await row in rowSequence { - XCTAssertEqual(try row.decode(column: 0, as: Int.self), counter) + XCTAssertEqual(try row.decode(Int.self, context: .forTests()), counter) counter += 1 } @@ -171,7 +171,7 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 0 for try await row in rowSequence { - XCTAssertEqual(try row.decode(column: 0, as: Int.self), counter) + XCTAssertEqual(try row.decode(Int.self, context: .forTests()), counter) counter += 1 } @@ -232,7 +232,7 @@ final class PostgresRowSequenceTests: XCTestCase { } let row1 = try await rowIterator.next() - XCTAssertEqual(try row1?.decode(column: 0, as: Int.self), 0) + XCTAssertEqual(try row1?.decode(Int.self, context: .forTests()), 0) DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { stream.receive(completion: .success("SELECT 1")) @@ -266,7 +266,7 @@ final class PostgresRowSequenceTests: XCTestCase { } let row1 = try await rowIterator.next() - XCTAssertEqual(try row1?.decode(column: 0, as: Int.self), 0) + XCTAssertEqual(try row1?.decode(Int.self, context: .forTests()), 0) DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { stream.receive(completion: .failure(PSQLError.connectionClosed)) @@ -433,7 +433,7 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 1 for _ in 0..<(2 * messagePerChunk - 1) { let row = try await rowIterator.next() - XCTAssertEqual(try row?.decode(column: 0, as: Int.self), counter) + XCTAssertEqual(try row?.decode(Int.self, context: .forTests()), counter) counter += 1 } diff --git a/Tests/PostgresNIOTests/New/PostgresRowTests.swift b/Tests/PostgresNIOTests/New/PostgresRowTests.swift new file mode 100644 index 00000000..7a67823b --- /dev/null +++ b/Tests/PostgresNIOTests/New/PostgresRowTests.swift @@ -0,0 +1,124 @@ +import XCTest +@testable import PostgresNIO + +final class PostgresRowTests: XCTestCase { + + func testSequence() { + let rowDescription = [ + RowDescription.Column( + name: "id", + tableOID: 1, + columnAttributeNumber: 1, + dataType: .uuid, + dataTypeSize: 0, + dataTypeModifier: 0, + format: .binary + ), + RowDescription.Column( + name: "name", + tableOID: 1, + columnAttributeNumber: 1, + dataType: .text, + dataTypeSize: 0, + dataTypeModifier: 0, + format: .binary + ) + ] + + let row = PostgresRow( + data: .makeTestDataRow(nil, ByteBuffer(string: "Hello world!")), + lookupTable: ["id": 0, "name": 1], + columns: rowDescription + ) + + XCTAssertEqual(row.count, 2) + var iterator = row.makeIterator() + + XCTAssertEqual(iterator.next(), PostgresCell(bytes: nil, dataType: .uuid, format: .binary, columnName: "id", columnIndex: 0)) + XCTAssertEqual(iterator.next(), PostgresCell(bytes: ByteBuffer(string: "Hello world!"), dataType: .text, format: .binary, columnName: "name", columnIndex: 1)) + XCTAssertNil(iterator.next()) + } + + func testCollection() { + let rowDescription = [ + RowDescription.Column( + name: "id", + tableOID: 1, + columnAttributeNumber: 1, + dataType: .uuid, + dataTypeSize: 0, + dataTypeModifier: 0, + format: .binary + ), + RowDescription.Column( + name: "name", + tableOID: 1, + columnAttributeNumber: 1, + dataType: .text, + dataTypeSize: 0, + dataTypeModifier: 0, + format: .binary + ) + ] + + let row = PostgresRow( + data: .makeTestDataRow(nil, ByteBuffer(string: "Hello world!")), + lookupTable: ["id": 0, "name": 1], + columns: rowDescription + ) + + XCTAssertEqual(row.count, 2) + let startIndex = row.startIndex + let secondIndex = row.index(after: startIndex) + let endIndex = row.index(after: secondIndex) + XCTAssertLessThan(startIndex, secondIndex) + XCTAssertLessThan(secondIndex, endIndex) + XCTAssertEqual(endIndex, row.endIndex) + + XCTAssertEqual(row[startIndex], PostgresCell(bytes: nil, dataType: .uuid, format: .binary, columnName: "id", columnIndex: 0)) + XCTAssertEqual(row[secondIndex], PostgresCell(bytes: ByteBuffer(string: "Hello world!"), dataType: .text, format: .binary, columnName: "name", columnIndex: 1)) + } + + func testRandomAccessRow() { + let rowDescription = [ + RowDescription.Column( + name: "id", + tableOID: 1, + columnAttributeNumber: 1, + dataType: .uuid, + dataTypeSize: 0, + dataTypeModifier: 0, + format: .binary + ), + RowDescription.Column( + name: "name", + tableOID: 1, + columnAttributeNumber: 1, + dataType: .text, + dataTypeSize: 0, + dataTypeModifier: 0, + format: .binary + ) + ] + + let row = PostgresRow( + data: .makeTestDataRow(nil, ByteBuffer(string: "Hello world!")), + lookupTable: ["id": 0, "name": 1], + columns: rowDescription + ) + + let randomAccessRow = row.makeRandomAccess() + + XCTAssertEqual(randomAccessRow.count, 2) + let startIndex = randomAccessRow.startIndex + let endIndex = randomAccessRow.endIndex + XCTAssertEqual(startIndex, 0) + XCTAssertEqual(endIndex, 2) + + XCTAssertEqual(randomAccessRow[0], PostgresCell(bytes: nil, dataType: .uuid, format: .binary, columnName: "id", columnIndex: 0)) + XCTAssertEqual(randomAccessRow[1], PostgresCell(bytes: ByteBuffer(string: "Hello world!"), dataType: .text, format: .binary, columnName: "name", columnIndex: 1)) + + XCTAssertEqual(randomAccessRow["id"], PostgresCell(bytes: nil, dataType: .uuid, format: .binary, columnName: "id", columnIndex: 0)) + XCTAssertEqual(randomAccessRow["name"], PostgresCell(bytes: ByteBuffer(string: "Hello world!"), dataType: .text, format: .binary, columnName: "name", columnIndex: 1)) + } +} diff --git a/dev/generate-psqlrow-multi-decode.sh b/dev/generate-psqlrow-multi-decode.sh index 5fee4a93..84652339 100755 --- a/dev/generate-psqlrow-multi-decode.sh +++ b/dev/generate-psqlrow-multi-decode.sh @@ -11,7 +11,7 @@ function gen() { echo "" fi - echo " @inlinable" + #echo " @inlinable" #echo " @_alwaysEmitIntoClient" echo -n " func decode Date: Sat, 26 Feb 2022 02:21:58 +0100 Subject: [PATCH 066/292] Replace all EncoderContext/DecoderContext uses of .forTests with .default (#231) --- .../New/Data/Array+PSQLCodableTests.swift | 24 ++++++------- .../New/Data/Bool+PSQLCodableTests.swift | 18 +++++----- .../New/Data/Bytes+PSQLCodableTests.swift | 10 +++--- .../New/Data/Date+PSQLCodableTests.swift | 20 +++++------ .../New/Data/Decimal+PSQLCodableTests.swift | 6 ++-- .../New/Data/Float+PSQLCodableTests.swift | 36 +++++++++---------- .../New/Data/JSON+PSQLCodableTests.swift | 12 +++---- .../RawRepresentable+PSQLCodableTests.swift | 8 ++--- .../New/Data/String+PSQLCodableTests.swift | 14 ++++---- .../New/Data/UUID+PSQLCodableTests.swift | 16 ++++----- .../New/Extensions/PSQLCoding+TestUtils.swift | 14 -------- .../New/PSQLRowStreamTests.swift | 6 ++-- .../New/PostgresCellTests.swift | 6 ++-- .../New/PostgresRowSequenceTests.swift | 12 +++---- 14 files changed, 94 insertions(+), 108 deletions(-) delete mode 100644 Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift diff --git a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift index a155399f..62a6629f 100644 --- a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift @@ -61,10 +61,10 @@ class Array_PSQLCodableTests: XCTestCase { let values = ["foo", "bar", "hello", "world"] var buffer = ByteBuffer() - XCTAssertNoThrow(try values.encode(into: &buffer, context: .forTests())) + XCTAssertNoThrow(try values.encode(into: &buffer, context: .default)) var result: [String]? - XCTAssertNoThrow(result = try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .default)) XCTAssertEqual(values, result) } @@ -72,10 +72,10 @@ class Array_PSQLCodableTests: XCTestCase { let values: [String] = [] var buffer = ByteBuffer() - XCTAssertNoThrow(try values.encode(into: &buffer, context: .forTests())) + XCTAssertNoThrow(try values.encode(into: &buffer, context: .default)) var result: [String]? - XCTAssertNoThrow(result = try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .default)) XCTAssertEqual(values, result) } @@ -85,7 +85,7 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int32(0)) buffer.writeInteger(String.psqlArrayElementType.rawValue) - XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -96,7 +96,7 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int32(1)) // invalid value, must always be 0 buffer.writeInteger(String.psqlArrayElementType.rawValue) - XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -104,9 +104,9 @@ class Array_PSQLCodableTests: XCTestCase { func testDecodeFailureTriesDecodeInt8() { let value: Int64 = 1 << 32 var buffer = ByteBuffer() - value.encode(into: &buffer, context: .forTests()) + value.encode(into: &buffer, context: .default) - XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -119,7 +119,7 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int32(-123)) // expected element count buffer.writeInteger(Int32(1)) // dimensions... must be one - XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -132,7 +132,7 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int32(1)) // expected element count buffer.writeInteger(Int32(2)) // dimensions... must be one - XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -146,7 +146,7 @@ class Array_PSQLCodableTests: XCTestCase { unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // dimensions unexpectedEndInElementLengthBuffer.writeInteger(Int16(1)) // length of element, must be Int32 - XCTAssertThrowsError(try [String].decode(from: &unexpectedEndInElementLengthBuffer, type: .textArray, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try [String].decode(from: &unexpectedEndInElementLengthBuffer, type: .textArray, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } @@ -159,7 +159,7 @@ class Array_PSQLCodableTests: XCTestCase { unexpectedEndInElementBuffer.writeInteger(Int32(12)) // length of element, must be Int32 unexpectedEndInElementBuffer.writeString("Hello World") // only 11 bytes, 12 needed! - XCTAssertThrowsError(try [String].decode(from: &unexpectedEndInElementBuffer, type: .textArray, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try [String].decode(from: &unexpectedEndInElementBuffer, type: .textArray, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } diff --git a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift index 773a35b8..f9c8103b 100644 --- a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift @@ -10,14 +10,14 @@ class Bool_PSQLCodableTests: XCTestCase { let value = true var buffer = ByteBuffer() - value.encode(into: &buffer, context: .forTests()) + value.encode(into: &buffer, context: .default) XCTAssertEqual(value.psqlType, .bool) XCTAssertEqual(value.psqlFormat, .binary) XCTAssertEqual(buffer.readableBytes, 1) XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 1) var result: Bool? - XCTAssertNoThrow(result = try Bool.decode(from: &buffer, type: .bool, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try Bool.decode(from: &buffer, type: .bool, format: .binary, context: .default)) XCTAssertEqual(value, result) } @@ -25,14 +25,14 @@ class Bool_PSQLCodableTests: XCTestCase { let value = false var buffer = ByteBuffer() - value.encode(into: &buffer, context: .forTests()) + value.encode(into: &buffer, context: .default) XCTAssertEqual(value.psqlType, .bool) XCTAssertEqual(value.psqlFormat, .binary) XCTAssertEqual(buffer.readableBytes, 1) XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 0) var result: Bool? - XCTAssertNoThrow(result = try Bool.decode(from: &buffer, type: .bool, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try Bool.decode(from: &buffer, type: .bool, format: .binary, context: .default)) XCTAssertEqual(value, result) } @@ -40,7 +40,7 @@ class Bool_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int64(1)) - XCTAssertThrowsError(try Bool.decode(from: &buffer, type: .bool, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try Bool.decode(from: &buffer, type: .bool, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -49,7 +49,7 @@ class Bool_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(UInt8(13)) - XCTAssertThrowsError(try Bool.decode(from: &buffer, type: .bool, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try Bool.decode(from: &buffer, type: .bool, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -63,7 +63,7 @@ class Bool_PSQLCodableTests: XCTestCase { buffer.writeInteger(UInt8(ascii: "t")) var result: Bool? - XCTAssertNoThrow(result = try Bool.decode(from: &buffer, type: .bool, format: .text, context: .forTests())) + XCTAssertNoThrow(result = try Bool.decode(from: &buffer, type: .bool, format: .text, context: .default)) XCTAssertEqual(value, result) } @@ -74,7 +74,7 @@ class Bool_PSQLCodableTests: XCTestCase { buffer.writeInteger(UInt8(ascii: "f")) var result: Bool? - XCTAssertNoThrow(result = try Bool.decode(from: &buffer, type: .bool, format: .text, context: .forTests())) + XCTAssertNoThrow(result = try Bool.decode(from: &buffer, type: .bool, format: .text, context: .default)) XCTAssertEqual(value, result) } @@ -82,7 +82,7 @@ class Bool_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(UInt8(13)) - XCTAssertThrowsError(try Bool.decode(from: &buffer, type: .bool, format: .text, context: .forTests())) { + XCTAssertThrowsError(try Bool.decode(from: &buffer, type: .bool, format: .text, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } diff --git a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift index 9747ec19..1dee1e06 100644 --- a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift @@ -8,11 +8,11 @@ class Bytes_PSQLCodableTests: XCTestCase { let data = Data((0...UInt8.max)) var buffer = ByteBuffer() - data.encode(into: &buffer, context: .forTests()) + data.encode(into: &buffer, context: .default) XCTAssertEqual(data.psqlType, .bytea) var result: Data? - XCTAssertNoThrow(result = try Data.decode(from: &buffer, type: .bytea, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try Data.decode(from: &buffer, type: .bytea, format: .binary, context: .default)) XCTAssertEqual(data, result) } @@ -20,11 +20,11 @@ class Bytes_PSQLCodableTests: XCTestCase { let bytes = ByteBuffer(bytes: (0...UInt8.max)) var buffer = ByteBuffer() - bytes.encode(into: &buffer, context: .forTests()) + bytes.encode(into: &buffer, context: .default) XCTAssertEqual(bytes.psqlType, .bytea) var result: ByteBuffer? - XCTAssertNoThrow(result = try ByteBuffer.decode(from: &buffer, type: .bytea, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try ByteBuffer.decode(from: &buffer, type: .bytea, format: .binary, context: .default)) XCTAssertEqual(bytes, result) } @@ -46,7 +46,7 @@ class Bytes_PSQLCodableTests: XCTestCase { let sequence = ByteSequence() var buffer = ByteBuffer() - sequence.encode(into: &buffer, context: .forTests()) + sequence.encode(into: &buffer, context: .default) XCTAssertEqual(sequence.psqlType, .bytea) XCTAssertEqual(buffer.readableBytes, 256) } diff --git a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift index 87eb46de..02bc4e97 100644 --- a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift @@ -8,12 +8,12 @@ class Date_PSQLCodableTests: XCTestCase { let value = Date() var buffer = ByteBuffer() - value.encode(into: &buffer, context: .forTests()) + value.encode(into: &buffer, context: .default) XCTAssertEqual(value.psqlType, .timestamptz) XCTAssertEqual(buffer.readableBytes, 8) var result: Date? - XCTAssertNoThrow(result = try Date.decode(from: &buffer, type: .timestamptz, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try Date.decode(from: &buffer, type: .timestamptz, format: .binary, context: .default)) XCTAssertEqual(value, result) } @@ -22,7 +22,7 @@ class Date_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) var result: Date? - XCTAssertNoThrow(result = try Date.decode(from: &buffer, type: .timestamptz, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try Date.decode(from: &buffer, type: .timestamptz, format: .binary, context: .default)) XCTAssertNotNil(result) } @@ -31,7 +31,7 @@ class Date_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) - XCTAssertThrowsError(try Date.decode(from: &buffer, type: .timestamptz, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try Date.decode(from: &buffer, type: .timestamptz, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -41,14 +41,14 @@ class Date_PSQLCodableTests: XCTestCase { firstDateBuffer.writeInteger(Int32.min) var firstDate: Date? - XCTAssertNoThrow(firstDate = try Date.decode(from: &firstDateBuffer, type: .date, format: .binary, context: .forTests())) + XCTAssertNoThrow(firstDate = try Date.decode(from: &firstDateBuffer, type: .date, format: .binary, context: .default)) XCTAssertNotNil(firstDate) var lastDateBuffer = ByteBuffer() lastDateBuffer.writeInteger(Int32.max) var lastDate: Date? - XCTAssertNoThrow(lastDate = try Date.decode(from: &lastDateBuffer, type: .date, format: .binary, context: .forTests())) + XCTAssertNoThrow(lastDate = try Date.decode(from: &lastDateBuffer, type: .date, format: .binary, context: .default)) XCTAssertNotNil(lastDate) } @@ -57,14 +57,14 @@ class Date_PSQLCodableTests: XCTestCase { firstDateBuffer.writeInteger(Int32.min) var firstDate: Date? - XCTAssertNoThrow(firstDate = try Date.decode(from: &firstDateBuffer, type: .date, format: .binary, context: .forTests())) + XCTAssertNoThrow(firstDate = try Date.decode(from: &firstDateBuffer, type: .date, format: .binary, context: .default)) XCTAssertNotNil(firstDate) var lastDateBuffer = ByteBuffer() lastDateBuffer.writeInteger(Int32.max) var lastDate: Date? - XCTAssertNoThrow(lastDate = try Date.decode(from: &lastDateBuffer, type: .date, format: .binary, context: .forTests())) + XCTAssertNoThrow(lastDate = try Date.decode(from: &lastDateBuffer, type: .date, format: .binary, context: .default)) XCTAssertNotNil(lastDate) } @@ -72,7 +72,7 @@ class Date_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int64(0)) - XCTAssertThrowsError(try Date.decode(from: &buffer, type: .date, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try Date.decode(from: &buffer, type: .date, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -81,7 +81,7 @@ class Date_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int64(0)) - XCTAssertThrowsError(try Date.decode(from: &buffer, type: .int8, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try Date.decode(from: &buffer, type: .int8, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } } diff --git a/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift index 8348c848..5e385de9 100644 --- a/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift @@ -9,11 +9,11 @@ class Decimal_PSQLCodableTests: XCTestCase { for value in values { var buffer = ByteBuffer() - value.encode(into: &buffer, context: .forTests()) + value.encode(into: &buffer, context: .default) XCTAssertEqual(value.psqlType, .numeric) var result: Decimal? - XCTAssertNoThrow(result = try Decimal.decode(from: &buffer, type: .numeric, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try Decimal.decode(from: &buffer, type: .numeric, format: .binary, context: .default)) XCTAssertEqual(value, result) } } @@ -22,7 +22,7 @@ class Decimal_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int64(0)) - XCTAssertThrowsError(try Decimal.decode(from: &buffer, type: .int8, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try Decimal.decode(from: &buffer, type: .int8, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } } diff --git a/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift index 108b99ec..5bd6eacb 100644 --- a/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift @@ -9,12 +9,12 @@ class Float_PSQLCodableTests: XCTestCase { for value in values { var buffer = ByteBuffer() - value.encode(into: &buffer, context: .forTests()) + value.encode(into: &buffer, context: .default) XCTAssertEqual(value.psqlType, .float8) XCTAssertEqual(buffer.readableBytes, 8) var result: Double? - XCTAssertNoThrow(result = try Double.decode(from: &buffer, type: .float8, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try Double.decode(from: &buffer, type: .float8, format: .binary, context: .default)) XCTAssertEqual(value, result) } } @@ -24,12 +24,12 @@ class Float_PSQLCodableTests: XCTestCase { for value in values { var buffer = ByteBuffer() - value.encode(into: &buffer, context: .forTests()) + value.encode(into: &buffer, context: .default) XCTAssertEqual(value.psqlType, .float4) XCTAssertEqual(buffer.readableBytes, 4) var result: Float? - XCTAssertNoThrow(result = try Float.decode(from: &buffer, type: .float4, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try Float.decode(from: &buffer, type: .float4, format: .binary, context: .default)) XCTAssertEqual(value, result) } } @@ -38,12 +38,12 @@ class Float_PSQLCodableTests: XCTestCase { let value: Double = .nan var buffer = ByteBuffer() - value.encode(into: &buffer, context: .forTests()) + value.encode(into: &buffer, context: .default) XCTAssertEqual(value.psqlType, .float8) XCTAssertEqual(buffer.readableBytes, 8) var result: Double? - XCTAssertNoThrow(result = try Double.decode(from: &buffer, type: .float8, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try Double.decode(from: &buffer, type: .float8, format: .binary, context: .default)) XCTAssertEqual(result?.isNaN, true) } @@ -51,12 +51,12 @@ class Float_PSQLCodableTests: XCTestCase { let value: Double = .infinity var buffer = ByteBuffer() - value.encode(into: &buffer, context: .forTests()) + value.encode(into: &buffer, context: .default) XCTAssertEqual(value.psqlType, .float8) XCTAssertEqual(buffer.readableBytes, 8) var result: Double? - XCTAssertNoThrow(result = try Double.decode(from: &buffer, type: .float8, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try Double.decode(from: &buffer, type: .float8, format: .binary, context: .default)) XCTAssertEqual(result?.isInfinite, true) } @@ -65,12 +65,12 @@ class Float_PSQLCodableTests: XCTestCase { for value in values { var buffer = ByteBuffer() - value.encode(into: &buffer, context: .forTests()) + value.encode(into: &buffer, context: .default) XCTAssertEqual(value.psqlType, .float4) XCTAssertEqual(buffer.readableBytes, 4) var result: Double? - XCTAssertNoThrow(result = try Double.decode(from: &buffer, type: .float4, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try Double.decode(from: &buffer, type: .float4, format: .binary, context: .default)) XCTAssertEqual(result, Double(value)) } } @@ -80,12 +80,12 @@ class Float_PSQLCodableTests: XCTestCase { for value in values { var buffer = ByteBuffer() - value.encode(into: &buffer, context: .forTests()) + value.encode(into: &buffer, context: .default) XCTAssertEqual(value.psqlType, .float8) XCTAssertEqual(buffer.readableBytes, 8) var result: Float? - XCTAssertNoThrow(result = try Float.decode(from: &buffer, type: .float8, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try Float.decode(from: &buffer, type: .float8, format: .binary, context: .default)) XCTAssertEqual(result, Float(value)) } } @@ -97,22 +97,22 @@ class Float_PSQLCodableTests: XCTestCase { fourByteBuffer.writeInteger(Int32(0)) var toLongBuffer1 = eightByteBuffer - XCTAssertThrowsError(try Double.decode(from: &toLongBuffer1, type: .float4, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try Double.decode(from: &toLongBuffer1, type: .float4, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } var toLongBuffer2 = eightByteBuffer - XCTAssertThrowsError(try Float.decode(from: &toLongBuffer2, type: .float4, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try Float.decode(from: &toLongBuffer2, type: .float4, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } var toShortBuffer1 = fourByteBuffer - XCTAssertThrowsError(try Double.decode(from: &toShortBuffer1, type: .float8, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try Double.decode(from: &toShortBuffer1, type: .float8, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } var toShortBuffer2 = fourByteBuffer - XCTAssertThrowsError(try Float.decode(from: &toShortBuffer2, type: .float8, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try Float.decode(from: &toShortBuffer2, type: .float8, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -122,12 +122,12 @@ class Float_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int64(0)) var copy1 = buffer - XCTAssertThrowsError(try Double.decode(from: ©1, type: .int8, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try Double.decode(from: ©1, type: .int8, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } var copy2 = buffer - XCTAssertThrowsError(try Float.decode(from: ©2, type: .int8, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try Float.decode(from: ©2, type: .int8, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } } diff --git a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift index d5ade4c7..04085168 100644 --- a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift @@ -15,14 +15,14 @@ class JSON_PSQLCodableTests: XCTestCase { func testRoundTrip() { var buffer = ByteBuffer() let hello = Hello(name: "world") - XCTAssertNoThrow(try hello.encode(into: &buffer, context: .forTests())) + XCTAssertNoThrow(try hello.encode(into: &buffer, context: .default)) XCTAssertEqual(hello.psqlType, .jsonb) // verify jsonb prefix byte XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 1) var result: Hello? - XCTAssertNoThrow(result = try Hello.decode(from: &buffer, type: .jsonb, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try Hello.decode(from: &buffer, type: .jsonb, format: .binary, context: .default)) XCTAssertEqual(result, hello) } @@ -31,7 +31,7 @@ class JSON_PSQLCodableTests: XCTestCase { buffer.writeString(#"{"hello":"world"}"#) var result: Hello? - XCTAssertNoThrow(result = try Hello.decode(from: &buffer, type: .json, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try Hello.decode(from: &buffer, type: .json, format: .binary, context: .default)) XCTAssertEqual(result, Hello(name: "world")) } @@ -45,7 +45,7 @@ class JSON_PSQLCodableTests: XCTestCase { for (format, dataType) in combinations { var loopBuffer = buffer var result: Hello? - XCTAssertNoThrow(result = try Hello.decode(from: &loopBuffer, type: dataType, format: format, context: .forTests())) + XCTAssertNoThrow(result = try Hello.decode(from: &loopBuffer, type: dataType, format: format, context: .default)) XCTAssertEqual(result, Hello(name: "world")) } } @@ -54,7 +54,7 @@ class JSON_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeString(#"{"hello":"world"}"#) - XCTAssertThrowsError(try Hello.decode(from: &buffer, type: .jsonb, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try Hello.decode(from: &buffer, type: .jsonb, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -63,7 +63,7 @@ class JSON_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeString(#"{"hello":"world"}"#) - XCTAssertThrowsError(try Hello.decode(from: &buffer, type: .text, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try Hello.decode(from: &buffer, type: .text, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } } diff --git a/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift index 1e515f4c..d017d00e 100644 --- a/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift @@ -15,12 +15,12 @@ class RawRepresentable_PSQLCodableTests: XCTestCase { for value in values { var buffer = ByteBuffer() - XCTAssertNoThrow(try value.encode(into: &buffer, context: .forTests())) + XCTAssertNoThrow(try value.encode(into: &buffer, context: .default)) XCTAssertEqual(value.psqlType, Int16.psqlArrayElementType) XCTAssertEqual(buffer.readableBytes, 2) var result: MyRawRepresentable? - XCTAssertNoThrow(result = try MyRawRepresentable.decode(from: &buffer, type: Int16.psqlArrayElementType, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try MyRawRepresentable.decode(from: &buffer, type: Int16.psqlArrayElementType, format: .binary, context: .default)) XCTAssertEqual(value, result) } } @@ -29,7 +29,7 @@ class RawRepresentable_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int16(4)) // out of bounds - XCTAssertThrowsError(try MyRawRepresentable.decode(from: &buffer, type: Int16.psqlArrayElementType, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try MyRawRepresentable.decode(from: &buffer, type: Int16.psqlArrayElementType, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -38,7 +38,7 @@ class RawRepresentable_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int32(1)) // out of bounds - XCTAssertThrowsError(try MyRawRepresentable.decode(from: &buffer, type: Int32.psqlArrayElementType, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try MyRawRepresentable.decode(from: &buffer, type: Int32.psqlArrayElementType, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } diff --git a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift index 12d9d9e2..e4c62704 100644 --- a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift @@ -8,7 +8,7 @@ class String_PSQLCodableTests: XCTestCase { let value = "Hello World" var buffer = ByteBuffer() - value.encode(into: &buffer, context: .forTests()) + value.encode(into: &buffer, context: .default) XCTAssertEqual(value.psqlType, .text) XCTAssertEqual(buffer.readString(length: buffer.readableBytes), value) @@ -26,7 +26,7 @@ class String_PSQLCodableTests: XCTestCase { for dataType in dataTypes { var loopBuffer = buffer var result: String? - XCTAssertNoThrow(result = try String.decode(from: &loopBuffer, type: dataType, format: .binary, context: .forTests())) + XCTAssertNoThrow(result = try String.decode(from: &loopBuffer, type: dataType, format: .binary, context: .default)) XCTAssertEqual(result, expected) } } @@ -37,7 +37,7 @@ class String_PSQLCodableTests: XCTestCase { for dataType in dataTypes { var loopBuffer = buffer - XCTAssertThrowsError(try String.decode(from: &loopBuffer, type: dataType, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try String.decode(from: &loopBuffer, type: dataType, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } } @@ -46,21 +46,21 @@ class String_PSQLCodableTests: XCTestCase { func testDecodeFromUUID() { let uuid = UUID() var buffer = ByteBuffer() - uuid.encode(into: &buffer, context: .forTests()) + uuid.encode(into: &buffer, context: .default) var decoded: String? - XCTAssertNoThrow(decoded = try String.decode(from: &buffer, type: .uuid, format: .binary, context: .forTests())) + XCTAssertNoThrow(decoded = try String.decode(from: &buffer, type: .uuid, format: .binary, context: .default)) XCTAssertEqual(decoded, uuid.uuidString) } func testDecodeFailureFromInvalidUUID() { let uuid = UUID() var buffer = ByteBuffer() - uuid.encode(into: &buffer, context: .forTests()) + uuid.encode(into: &buffer, context: .default) // this makes only 15 bytes readable. this should lead to an error buffer.moveReaderIndex(forwardBy: 1) - XCTAssertThrowsError(try String.decode(from: &buffer, type: .uuid, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try String.decode(from: &buffer, type: .uuid, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } diff --git a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift index 5add881a..840b8531 100644 --- a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift @@ -9,7 +9,7 @@ class UUID_PSQLCodableTests: XCTestCase { let uuid = UUID() var buffer = ByteBuffer() - uuid.encode(into: &buffer, context: .forTests()) + uuid.encode(into: &buffer, context: .default) XCTAssertEqual(uuid.psqlType, .uuid) XCTAssertEqual(uuid.psqlFormat, .binary) @@ -34,7 +34,7 @@ class UUID_PSQLCodableTests: XCTestCase { XCTAssertEqual(byteIterator.next(), uuid.uuid.15) var decoded: UUID? - XCTAssertNoThrow(decoded = try UUID.decode(from: &buffer, type: .uuid, format: .binary, context: .forTests())) + XCTAssertNoThrow(decoded = try UUID.decode(from: &buffer, type: .uuid, format: .binary, context: .default)) XCTAssertEqual(decoded, uuid) } } @@ -57,7 +57,7 @@ class UUID_PSQLCodableTests: XCTestCase { for (format, dataType) in options { var loopBuffer = lowercaseBuffer var decoded: UUID? - XCTAssertNoThrow(decoded = try UUID.decode(from: &loopBuffer, type: dataType, format: format, context: .forTests())) + XCTAssertNoThrow(decoded = try UUID.decode(from: &loopBuffer, type: dataType, format: format, context: .default)) XCTAssertEqual(decoded, uuid) } @@ -68,7 +68,7 @@ class UUID_PSQLCodableTests: XCTestCase { for (format, dataType) in options { var loopBuffer = uppercaseBuffer var decoded: UUID? - XCTAssertNoThrow(decoded = try UUID.decode(from: &loopBuffer, type: dataType, format: format, context: .forTests())) + XCTAssertNoThrow(decoded = try UUID.decode(from: &loopBuffer, type: dataType, format: format, context: .default)) XCTAssertEqual(decoded, uuid) } } @@ -78,11 +78,11 @@ class UUID_PSQLCodableTests: XCTestCase { let uuid = UUID() var buffer = ByteBuffer() - uuid.encode(into: &buffer, context: .forTests()) + uuid.encode(into: &buffer, context: .default) // this makes only 15 bytes readable. this should lead to an error buffer.moveReaderIndex(forwardBy: 1) - XCTAssertThrowsError(try UUID.decode(from: &buffer, type: .uuid, format: .binary, context: .forTests())) { error in + XCTAssertThrowsError(try UUID.decode(from: &buffer, type: .uuid, format: .binary, context: .default)) { error in XCTAssertEqual(error as? PostgresCastingError.Code, .failure) } } @@ -98,7 +98,7 @@ class UUID_PSQLCodableTests: XCTestCase { for dataType in dataTypes { var loopBuffer = buffer - XCTAssertThrowsError(try UUID.decode(from: &loopBuffer, type: dataType, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try UUID.decode(from: &loopBuffer, type: dataType, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -113,7 +113,7 @@ class UUID_PSQLCodableTests: XCTestCase { for dataType in dataTypes { var copy = buffer - XCTAssertThrowsError(try UUID.decode(from: ©, type: dataType, format: .binary, context: .forTests())) { + XCTAssertThrowsError(try UUID.decode(from: ©, type: dataType, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } } diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift deleted file mode 100644 index 212a18bd..00000000 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift +++ /dev/null @@ -1,14 +0,0 @@ -@testable import PostgresNIO -import Foundation - -extension PostgresDecodingContext where JSONDecoder == Foundation.JSONDecoder { - static func forTests() -> Self { - Self(jsonDecoder: JSONDecoder()) - } -} - -extension PostgresEncodingContext where JSONEncoder == Foundation.JSONEncoder { - static func forTests(jsonEncoder: JSONEncoder = JSONEncoder()) -> Self { - Self(jsonEncoder: jsonEncoder) - } -} diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift index 6a9dfbb5..5ca43591 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -170,7 +170,7 @@ class PSQLRowStreamTests: XCTestCase { // attach consumer var counter = 0 let future = stream.onRow { row in - XCTAssertEqual(try row.decode(String.self, context: .forTests()), "\(counter)") + XCTAssertEqual(try row.decode(String.self, context: .default), "\(counter)") counter += 1 } XCTAssertEqual(counter, 2) @@ -214,7 +214,7 @@ class PSQLRowStreamTests: XCTestCase { // attach consumer var counter = 0 let future = stream.onRow { row in - XCTAssertEqual(try row.decode(String.self, context: .forTests()), "\(counter)") + XCTAssertEqual(try row.decode(String.self, context: .default), "\(counter)") if counter == 1 { throw OnRowError(row: counter) } @@ -261,7 +261,7 @@ class PSQLRowStreamTests: XCTestCase { // attach consumer var counter = 0 let future = stream.onRow { row in - XCTAssertEqual(try row.decode(String.self, context: .forTests()), "\(counter)") + XCTAssertEqual(try row.decode(String.self, context: .default), "\(counter)") counter += 1 } XCTAssertEqual(counter, 2) diff --git a/Tests/PostgresNIOTests/New/PostgresCellTests.swift b/Tests/PostgresNIOTests/New/PostgresCellTests.swift index 0693f0b1..e7d1cb30 100644 --- a/Tests/PostgresNIOTests/New/PostgresCellTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresCellTests.swift @@ -12,7 +12,7 @@ final class PostgresCellTests: XCTestCase { ) var result: String? - XCTAssertNoThrow(result = try cell.decode(String.self, context: .forTests())) + XCTAssertNoThrow(result = try cell.decode(String.self, context: .default)) XCTAssertEqual(result, "Hello world") } @@ -26,7 +26,7 @@ final class PostgresCellTests: XCTestCase { ) var result: String? = "test" - XCTAssertNoThrow(result = try cell.decode(String?.self, context: .forTests())) + XCTAssertNoThrow(result = try cell.decode(String?.self, context: .default)) XCTAssertNil(result) } @@ -39,7 +39,7 @@ final class PostgresCellTests: XCTestCase { columnIndex: 1 ) - XCTAssertThrowsError(try cell.decode(Int?.self, context: .forTests())) { + XCTAssertThrowsError(try cell.decode(Int?.self, context: .default)) { guard let error = $0 as? PostgresCastingError else { return XCTFail("Unexpected error") } diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index 9d6c467d..9e01ff06 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -59,7 +59,7 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 0 for try await row in rowSequence { - XCTAssertEqual(try row.decode(Int.self, context: .forTests()), counter) + XCTAssertEqual(try row.decode(Int.self, context: .default), counter) counter += 1 if counter == 64 { @@ -141,7 +141,7 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 0 for try await row in rowSequence { - XCTAssertEqual(try row.decode(Int.self, context: .forTests()), counter) + XCTAssertEqual(try row.decode(Int.self, context: .default), counter) counter += 1 } @@ -171,7 +171,7 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 0 for try await row in rowSequence { - XCTAssertEqual(try row.decode(Int.self, context: .forTests()), counter) + XCTAssertEqual(try row.decode(Int.self, context: .default), counter) counter += 1 } @@ -232,7 +232,7 @@ final class PostgresRowSequenceTests: XCTestCase { } let row1 = try await rowIterator.next() - XCTAssertEqual(try row1?.decode(Int.self, context: .forTests()), 0) + XCTAssertEqual(try row1?.decode(Int.self, context: .default), 0) DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { stream.receive(completion: .success("SELECT 1")) @@ -266,7 +266,7 @@ final class PostgresRowSequenceTests: XCTestCase { } let row1 = try await rowIterator.next() - XCTAssertEqual(try row1?.decode(Int.self, context: .forTests()), 0) + XCTAssertEqual(try row1?.decode(Int.self, context: .default), 0) DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { stream.receive(completion: .failure(PSQLError.connectionClosed)) @@ -433,7 +433,7 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 1 for _ in 0..<(2 * messagePerChunk - 1) { let row = try await rowIterator.next() - XCTAssertEqual(try row?.decode(Int.self, context: .forTests()), counter) + XCTAssertEqual(try row?.decode(Int.self, context: .default), counter) counter += 1 } From 262208c59ad788652eaf26a40f39196062902b71 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 26 Feb 2022 10:18:59 +0100 Subject: [PATCH 067/292] Add an async query API (internal for now) (#233) --- .../Connection/PostgresConnection.swift | 28 ++++++++ Tests/IntegrationTests/AsyncTests.swift | 67 +++++++++++++++++++ 2 files changed, 95 insertions(+) create mode 100644 Tests/IntegrationTests/AsyncTests.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 0962482d..5377f110 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -92,6 +92,34 @@ extension PostgresConnection { } } +#if swift(>=5.5) && canImport(_Concurrency) +extension PostgresConnection { + func close() async throws { + try await self.close().get() + } + + func query(_ query: PostgresQuery, logger: Logger, file: String = #file, line: UInt = #line) async throws -> PostgresRowSequence { + var logger = logger + logger[postgresMetadataKey: .connectionID] = "\(self.underlying.connectionID)" + + do { + guard query.binds.count <= Int(Int16.max) else { + throw PSQLError.tooManyParameters + } + let promise = self.underlying.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let context = ExtendedQueryContext( + query: query, + logger: logger, + promise: promise) + + self.underlying.channel.write(PSQLTask.extendedQuery(context), promise: nil) + + return try await promise.futureResult.map({ $0.asyncSequence() }).get() + } + } +} +#endif + // MARK: PostgresDatabase extension PostgresConnection: PostgresDatabase { diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift new file mode 100644 index 00000000..593a06e0 --- /dev/null +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -0,0 +1,67 @@ +import Logging +import XCTest +@testable import PostgresNIO + +#if swift(>=5.5.2) +final class AsyncPostgresConnectionTests: XCTestCase { + + func test1kRoundTrips() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + try await withTestConnection(on: eventLoop) { connection in + for _ in 0..<1_000 { + let rows = try await connection.query("SELECT version()", logger: .psqlTest) + var iterator = rows.makeAsyncIterator() + let firstRow = try await iterator.next() + XCTAssertEqual(try firstRow?.decode(String.self, context: .default).contains("PostgreSQL"), true) + let done = try await iterator.next() + XCTAssertNil(done) + } + } + } + + func testSelect10kRows() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let start = 1 + let end = 10000 + + try await withTestConnection(on: eventLoop) { connection in + let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest) + var counter = 1 + for try await row in rows { + XCTAssertEqual(try row.decode(Int.self, context: .default), counter) + counter += 1 + } + + XCTAssertEqual(counter, end + 1) + } + } +} + +extension XCTestCase { + + func withTestConnection( + on eventLoop: EventLoop, + file: StaticString = #file, + line: UInt = #line, + _ closure: (PostgresConnection) async throws -> Result + ) async throws -> Result { + let connection = try await PostgresConnection.test(on: eventLoop).get() + + do { + let result = try await closure(connection) + try await connection.close() + return result + } catch { + XCTFail("Unexpected error: \(error)", file: file, line: line) + try await connection.close() + throw error + } + } +} +#endif From 4cd15673686c971fb77a96b7e6e2466771a224fc Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 26 Feb 2022 10:30:00 +0100 Subject: [PATCH 068/292] Add PostgresRowSequence multi decode (#232) --- .../PostgresRowSequence-multi-decode.swift | 95 +++++++++++++++++++ .../PostgresNIO/New/PostgresRowSequence.swift | 2 +- ...nerate-postgresrowsequence-multi-decode.sh | 73 ++++++++++++++ 3 files changed, 169 insertions(+), 1 deletion(-) create mode 100644 Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift create mode 100755 dev/generate-postgresrowsequence-multi-decode.sh diff --git a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift new file mode 100644 index 00000000..aea721e4 --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift @@ -0,0 +1,95 @@ +/// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrowsequence-multi-decode.sh + +#if swift(>=5.5) && canImport(_Concurrency) +extension PostgresRowSequence { + func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode(T0.self, context: context, file: file, line: line) + } + } + + func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode((T0, T1).self, context: context, file: file, line: line) + } + } + + func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode((T0, T1, T2).self, context: context, file: file, line: line) + } + } + + func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode((T0, T1, T2, T3).self, context: context, file: file, line: line) + } + } + + func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode((T0, T1, T2, T3, T4).self, context: context, file: file, line: line) + } + } + + func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode((T0, T1, T2, T3, T4, T5).self, context: context, file: file, line: line) + } + } + + func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode((T0, T1, T2, T3, T4, T5, T6).self, context: context, file: file, line: line) + } + } + + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: context, file: file, line: line) + } + } + + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: context, file: file, line: line) + } + } + + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: context, file: file, line: line) + } + } + + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: context, file: file, line: line) + } + } + + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: context, file: file, line: line) + } + } + + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: context, file: file, line: line) + } + } + + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: context, file: file, line: line) + } + } + + func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: context, file: file, line: line) + } + } +} +#endif diff --git a/Sources/PostgresNIO/New/PostgresRowSequence.swift b/Sources/PostgresNIO/New/PostgresRowSequence.swift index 0a7765a5..a68681fa 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence.swift @@ -2,7 +2,7 @@ import NIOCore import NIOConcurrencyHelpers #if swift(>=5.5) && canImport(_Concurrency) -/// An async sequence of ``PSQLRow``s. +/// An async sequence of ``PostgresRow``s. /// /// - Note: This is a struct to allow us to move to a move only type easily once they become available. struct PostgresRowSequence: AsyncSequence { diff --git a/dev/generate-postgresrowsequence-multi-decode.sh b/dev/generate-postgresrowsequence-multi-decode.sh new file mode 100755 index 00000000..eb5ad9a0 --- /dev/null +++ b/dev/generate-postgresrowsequence-multi-decode.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +set -eu + +here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +function gen() { + how_many=$1 + + if [[ $how_many -ne 1 ]] ; then + echo "" + fi + + #echo " @inlinable" + #echo " @_alwaysEmitIntoClient" + echo -n " func decode(_: (T0" + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", T$(($n))" + done + echo -n ").Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) " + + echo -n "-> AsyncThrowingMapSequence {" + + echo " self.map { row in" + + if [[ $how_many -eq 1 ]] ; then + echo " try row.decode(T0.self, context: context, file: file, line: line)" + else + echo -n " try row.decode((T0" + + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", T$n" + done + echo ").self, context: context, file: file, line: line)" + + fi + + echo " }" + echo " }" +} + +grep -q "ByteBuffer" "${BASH_SOURCE[0]}" || { + echo >&2 "ERROR: ${BASH_SOURCE[0]}: file or directory not found (this should be this script)" + exit 1 +} + +{ +cat <<"EOF" +/// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrowsequence-multi-decode.sh +EOF +echo + +echo "#if swift(>=5.5) && canImport(_Concurrency)" +echo "extension PostgresRowSequence {" + +# note: +# - widening the inverval below (eg. going from {1..15} to {1..25}) is Semver minor +# - narrowing the interval below is SemVer _MAJOR_! +for n in {1..15}; do + gen "$n" +done +echo "}" +echo "#endif" +} > "$here/../Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift" From cedb5ade2e1af2347015b1d806fa468eef7233c3 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 26 Feb 2022 11:59:00 +0100 Subject: [PATCH 069/292] [BufferMessageEncoder] Reduce the number of force unwraps (#234) --- .../New/BufferedMessageEncoder.swift | 6 +---- .../PostgresNIO/New/PSQLChannelHandler.swift | 24 +++++++++---------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/Sources/PostgresNIO/New/BufferedMessageEncoder.swift b/Sources/PostgresNIO/New/BufferedMessageEncoder.swift index 9de1443d..0942b972 100644 --- a/Sources/PostgresNIO/New/BufferedMessageEncoder.swift +++ b/Sources/PostgresNIO/New/BufferedMessageEncoder.swift @@ -28,11 +28,7 @@ struct BufferedMessageEncoder { self.encoder.encode(data: message, out: &self.buffer) } - mutating func flush() -> ByteBuffer? { - guard self.buffer.readableBytes > 0 else { - return nil - } - + mutating func flush() -> ByteBuffer { self.state = .flushed return self.buffer } diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PSQLChannelHandler.swift index 575bf02c..c24ee07d 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PSQLChannelHandler.swift @@ -229,18 +229,18 @@ final class PSQLChannelHandler: ChannelDuplexHandler { break case .sendStartupMessage(let authContext): self.encoder.encode(.startup(.versionThree(parameters: authContext.toStartupParameters()))) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) case .sendSSLRequest: self.encoder.encode(.sslRequest(.init())) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) case .sendPasswordMessage(let mode, let authContext): self.sendPasswordMessage(mode: mode, authContext: authContext, context: context) case .sendSaslInitialResponse(let name, let initialResponse): self.encoder.encode(.saslInitialResponse(.init(saslMechanism: name, initialData: initialResponse))) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) case .sendSaslResponse(let bytes): self.encoder.encode(.saslResponse(.init(data: bytes))) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) case .closeConnectionAndCleanup(let cleanupContext): self.closeConnectionAndCleanup(cleanupContext, context: context) case .fireChannelInactive: @@ -304,7 +304,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { // message and immediately closes the connection. On receipt of this message, the // backend closes the connection and terminates. self.encoder.encode(.terminate) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) } context.close(mode: .all, promise: promise) case .succeedPreparedStatementCreation(let preparedContext, with: let rowDescription): @@ -369,11 +369,11 @@ final class PSQLChannelHandler: ChannelDuplexHandler { let hash = Insecure.MD5.hash(data: hash2).md5PrefixHexdigest() self.encoder.encode(.password(.init(value: hash))) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) case .cleartext: self.encoder.encode(.password(.init(value: authContext.password ?? ""))) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) } } @@ -382,12 +382,12 @@ final class PSQLChannelHandler: ChannelDuplexHandler { case .preparedStatement(let name): self.encoder.encode(.close(.preparedStatement(name))) self.encoder.encode(.sync) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) case .portal(let name): self.encoder.encode(.close(.portal(name))) self.encoder.encode(.sync) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) } } @@ -405,7 +405,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { self.encoder.encode(.parse(parse)) self.encoder.encode(.describe(.preparedStatement(statementName))) self.encoder.encode(.sync) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) } private func sendBindExecuteAndSyncMessage( @@ -420,7 +420,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { self.encoder.encode(.bind(bind)) self.encoder.encode(.execute(.init(portalName: ""))) self.encoder.encode(.sync) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) } private func sendParseDescribeBindExecuteAndSyncMessage( @@ -443,7 +443,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { self.encoder.encode(.bind(bind)) self.encoder.encode(.execute(.init(portalName: ""))) self.encoder.encode(.sync) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) } private func succeedQueryWithRowStream( From 967bf01e9dbebb787feb4f31ce95439f13612b14 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 26 Feb 2022 12:28:14 +0100 Subject: [PATCH 070/292] [PostgresRowSequence] Make StateMachine private (#235) --- .../PostgresNIO/New/PostgresRowSequence.swift | 179 +++++++++--------- 1 file changed, 89 insertions(+), 90 deletions(-) diff --git a/Sources/PostgresNIO/New/PostgresRowSequence.swift b/Sources/PostgresNIO/New/PostgresRowSequence.swift index a68681fa..4a87b452 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence.swift @@ -8,31 +8,31 @@ import NIOConcurrencyHelpers struct PostgresRowSequence: AsyncSequence { typealias Element = PostgresRow typealias AsyncIterator = Iterator - + final class _Internal { - + let consumer: AsyncStreamConsumer - + init(consumer: AsyncStreamConsumer) { self.consumer = consumer } - + deinit { // if no iterator was created, we need to cancel the stream self.consumer.sequenceDeinitialized() } - + func makeAsyncIterator() -> Iterator { self.consumer.makeAsyncIterator() } } - + let _internal: _Internal - + init(_ consumer: AsyncStreamConsumer) { self._internal = .init(consumer: consumer) } - + func makeAsyncIterator() -> Iterator { self._internal.makeAsyncIterator() } @@ -41,20 +41,20 @@ struct PostgresRowSequence: AsyncSequence { extension PostgresRowSequence { struct Iterator: AsyncIteratorProtocol { typealias Element = PostgresRow - + let _internal: _Internal - + init(consumer: AsyncStreamConsumer) { self._internal = _Internal(consumer: consumer) } - + mutating func next() async throws -> PostgresRow? { try await self._internal.next() } - + final class _Internal { let consumer: AsyncStreamConsumer - + init(consumer: AsyncStreamConsumer) { self.consumer = consumer } @@ -62,7 +62,7 @@ extension PostgresRowSequence { deinit { self.consumer.iteratorDeinitialized() } - + func next() async throws -> PostgresRow? { try await self.consumer.next() } @@ -72,44 +72,44 @@ extension PostgresRowSequence { final class AsyncStreamConsumer { let lock = Lock() - + let lookupTable: [String: Int] let columns: [RowDescription.Column] private var state: StateMachine - + init( lookupTable: [String: Int], columns: [RowDescription.Column] ) { self.state = StateMachine() - + self.lookupTable = lookupTable self.columns = columns } - + func startCompleted(_ buffer: CircularBuffer, commandTag: String) { self.lock.withLock { self.state.finished(buffer, commandTag: commandTag) } } - + func startStreaming(_ buffer: CircularBuffer, upstream: PSQLRowStream) { self.lock.withLock { self.state.buffered(buffer, upstream: upstream) } } - + func startFailed(_ error: Error) { self.lock.withLock { self.state.failed(error) } } - + func receive(_ newRows: [DataRow]) { let receiveAction = self.lock.withLock { self.state.receive(newRows) } - + switch receiveAction { case .succeed(let continuation, let data, signalDemandTo: let source): let row = PostgresRow( @@ -119,34 +119,34 @@ final class AsyncStreamConsumer { ) continuation.resume(returning: row) source?.demand() - + case .none: break } } - + func receive(completion result: Result) { let completionAction = self.lock.withLock { self.state.receive(completion: result) } - + switch completionAction { case .succeed(let continuation): continuation.resume(returning: nil) - + case .fail(let continuation, let error): continuation.resume(throwing: error) - + case .none: break } } - + func sequenceDeinitialized() { let action = self.lock.withLock { self.state.sequenceDeinitialized() } - + switch action { case .cancelStream(let source): source.cancel() @@ -154,7 +154,7 @@ final class AsyncStreamConsumer { break } } - + func makeAsyncIterator() -> PostgresRowSequence.Iterator { self.lock.withLock { self.state.createAsyncIterator() @@ -182,7 +182,7 @@ final class AsyncStreamConsumer { case .returnNil: self.lock.unlock() return nil - + case .returnRow(let data, signalDemandTo: let source): self.lock.unlock() source?.demand() @@ -191,11 +191,11 @@ final class AsyncStreamConsumer { lookupTable: self.lookupTable, columns: self.columns ) - + case .throwError(let error): self.lock.unlock() throw error - + case .hitSlowPath: return try await withCheckedThrowingContinuation { continuation in let slowPathAction = self.state.next(for: continuation) @@ -213,13 +213,13 @@ final class AsyncStreamConsumer { } extension AsyncStreamConsumer { - struct StateMachine { - enum UpstreamState { + private struct StateMachine { + private enum UpstreamState { enum DemandState { case canAskForMore case waitingForMore(CheckedContinuation?) } - + case initialized /// The upstream has more data that can be received case streaming(AdaptiveRowBuffer, PSQLRowStream, DemandState) @@ -234,17 +234,17 @@ extension AsyncStreamConsumer { /// `.streaming` or `.finished` state. case modifying } - - enum DownstreamState { + + private enum DownstreamState { case sequenceCreated case iteratorCreated } - - var upstreamState = UpstreamState.initialized - var downstreamState = DownstreamState.sequenceCreated - + + private var upstreamState = UpstreamState.initialized + private var downstreamState = DownstreamState.sequenceCreated + init() {} - + mutating func buffered(_ buffer: CircularBuffer, upstream: PSQLRowStream) { switch self.upstreamState { case .initialized: @@ -255,7 +255,7 @@ extension AsyncStreamConsumer { preconditionFailure("Invalid upstream state: \(self.upstreamState)") } } - + mutating func finished(_ buffer: CircularBuffer, commandTag: String) { switch self.upstreamState { case .initialized: @@ -266,7 +266,7 @@ extension AsyncStreamConsumer { preconditionFailure("Invalid upstream state: \(self.upstreamState)") } } - + mutating func failed(_ error: Error) { switch self.upstreamState { case .initialized: @@ -276,7 +276,7 @@ extension AsyncStreamConsumer { preconditionFailure("Invalid upstream state: \(self.upstreamState)") } } - + mutating func createAsyncIterator() { switch self.downstreamState { case .sequenceCreated: @@ -285,28 +285,28 @@ extension AsyncStreamConsumer { preconditionFailure("An iterator already exists") } } - + enum SequenceDeinitializedAction { case cancelStream(PSQLRowStream) case none } - + mutating func sequenceDeinitialized() -> SequenceDeinitializedAction { switch (self.downstreamState, self.upstreamState) { case (.sequenceCreated, .initialized): preconditionFailure("Invalid state: \(self.downstreamState), \(self.upstreamState)") - + case (.sequenceCreated, .streaming(_, let source, _)): return .cancelStream(source) - + case (.sequenceCreated, .finished), (.sequenceCreated, .consumed), (.sequenceCreated, .failed): return .none - + case (.iteratorCreated, _): return .none - + case (_, .modifying): preconditionFailure("Invalid state: \(self.downstreamState), \(self.upstreamState)") } @@ -331,14 +331,13 @@ extension AsyncStreamConsumer { } } - enum NextFastPathAction { case hitSlowPath case throwError(Error) case returnRow(DataRow, signalDemandTo: PSQLRowStream?) case returnNil } - + mutating func next() -> NextFastPathAction { switch self.upstreamState { case .initialized: @@ -363,7 +362,7 @@ extension AsyncStreamConsumer { self.upstreamState = .streaming(buffer, source, .waitingForMore(.none)) return .hitSlowPath } - + self.upstreamState = .streaming(buffer, source, .waitingForMore(.none)) return .returnRow(data, signalDemandTo: nil) @@ -376,7 +375,7 @@ extension AsyncStreamConsumer { self.upstreamState = .consumed return .returnNil } - + self.upstreamState = .finished(buffer, commandTag) return .returnRow(data, signalDemandTo: nil) @@ -396,41 +395,41 @@ extension AsyncStreamConsumer { case signalDemand(PSQLRowStream) case none } - + mutating func next(for continuation: CheckedContinuation) -> NextSlowPathAction { switch self.upstreamState { case .initialized: preconditionFailure() - + case .streaming(let buffer, let source, .canAskForMore): precondition(buffer.isEmpty) self.upstreamState = .streaming(buffer, source, .waitingForMore(continuation)) return .signalDemand(source) - + case .streaming(let buffer, let source, .waitingForMore(.none)): precondition(buffer.isEmpty) self.upstreamState = .streaming(buffer, source, .waitingForMore(continuation)) return .none - + case .streaming(_, _, .waitingForMore(.some)), .finished, .failed, .consumed: preconditionFailure("Expected that state was already handled by fast path. Invalid upstream state: \(self.upstreamState)") - + case .modifying: preconditionFailure("Invalid upstream state: \(self.upstreamState)") } } - + enum ReceiveAction { case succeed(CheckedContinuation, DataRow, signalDemandTo: PSQLRowStream?) case none } - + mutating func receive(_ newRows: [DataRow]) -> ReceiveAction { precondition(!newRows.isEmpty) - + switch self.upstreamState { case .streaming(var buffer, let source, .waitingForMore(.some(let continuation))): buffer.append(contentsOf: newRows) @@ -441,34 +440,34 @@ extension AsyncStreamConsumer { } self.upstreamState = .streaming(buffer, source, .canAskForMore) return .succeed(continuation, first, signalDemandTo: nil) - + case .streaming(var buffer, let source, .waitingForMore(.none)): buffer.append(contentsOf: newRows) self.upstreamState = .streaming(buffer, source, .canAskForMore) return .none - + case .streaming(var buffer, let source, .canAskForMore): buffer.append(contentsOf: newRows) self.upstreamState = .streaming(buffer, source, .canAskForMore) return .none - + case .initialized, .finished, .consumed: preconditionFailure() - + case .failed: return .none - + case .modifying: preconditionFailure() } } - + enum CompletionResult { case succeed(CheckedContinuation) case fail(CheckedContinuation, Error) case none } - + mutating func receive(completion result: Result) -> CompletionResult { switch result { case .success(let commandTag): @@ -477,54 +476,54 @@ extension AsyncStreamConsumer { return self.receiveError(error) } } - + private mutating func receiveEnd(commandTag: String) -> CompletionResult { switch self.upstreamState { case .streaming(let buffer, _, .waitingForMore(.some(let continuation))): precondition(buffer.isEmpty) self.upstreamState = .consumed return .succeed(continuation) - + case .streaming(let buffer, _, .waitingForMore(.none)): self.upstreamState = .finished(buffer, commandTag) return .none - + case .streaming(let buffer, _, .canAskForMore): self.upstreamState = .finished(buffer, commandTag) return .none - + case .initialized, .finished, .consumed: preconditionFailure("Invalid upstream state: \(self.upstreamState)") - + case .failed: return .none - + case .modifying: preconditionFailure() } } - + private mutating func receiveError(_ error: Error) -> CompletionResult { switch self.upstreamState { case .streaming(let buffer, _, .waitingForMore(.some(let continuation))): precondition(buffer.isEmpty) self.upstreamState = .consumed return .fail(continuation, error) - + case .streaming(_, _, .waitingForMore(.none)): self.upstreamState = .failed(error) return .none - + case .streaming(_, _, .canAskForMore): self.upstreamState = .failed(error) return .none - + case .initialized, .finished, .consumed: preconditionFailure("Invalid upstream state: \(self.upstreamState)") - + case .failed: return .none - + case .modifying: preconditionFailure() } @@ -553,11 +552,11 @@ struct AdaptiveRowBuffer { private var circularBuffer: CircularBuffer private var target: Int private var canShrink: Bool = false - + var isEmpty: Bool { self.circularBuffer.isEmpty } - + init(minimum: Int, maximum: Int, target: Int, buffer: CircularBuffer) { precondition(minimum <= target && target <= maximum) self.minimum = minimum @@ -565,7 +564,7 @@ struct AdaptiveRowBuffer { self.target = target self.circularBuffer = buffer } - + init(_ circularBuffer: CircularBuffer) { self.init( minimum: Self.defaultBufferMinimum, @@ -574,7 +573,7 @@ struct AdaptiveRowBuffer { buffer: circularBuffer ) } - + mutating func append(contentsOf newRows: Rows) where Rows.Element == DataRow { self.circularBuffer.append(contentsOf: newRows) if self.circularBuffer.count >= self.target, self.canShrink, self.target > self.minimum { @@ -586,16 +585,16 @@ struct AdaptiveRowBuffer { /// Returns the next row in the FIFO buffer and a `bool` signalling if new rows should be loaded. mutating func removeFirst() -> (DataRow, Bool) { let element = self.circularBuffer.removeFirst() - + // If the buffer is drained now, we should double our target size. if self.circularBuffer.count == 0, self.target < self.maximum { self.target = self.target * 2 self.canShrink = false } - + return (element, self.circularBuffer.count < self.target) } - + mutating func popFirst() -> (DataRow, Bool)? { guard !self.circularBuffer.isEmpty else { return nil From 63e7d57039627252f5fdac9adef98a7018f82c18 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 26 Feb 2022 13:02:50 +0100 Subject: [PATCH 071/292] Rename PSQLRow multi decode to Postgres (#236) --- ...LRow-multi-decode.swift => PostgresRow-multi-decode.swift} | 2 +- ...w-multi-decode.sh => generate-postgresrow-multi-decode.sh} | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) rename Sources/PostgresNIO/New/{PSQLRow-multi-decode.swift => PostgresRow-multi-decode.swift} (99%) rename dev/{generate-psqlrow-multi-decode.sh => generate-postgresrow-multi-decode.sh} (95%) diff --git a/Sources/PostgresNIO/New/PSQLRow-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift similarity index 99% rename from Sources/PostgresNIO/New/PSQLRow-multi-decode.swift rename to Sources/PostgresNIO/New/PostgresRow-multi-decode.swift index ef67c7ac..1e1a426d 100644 --- a/Sources/PostgresNIO/New/PSQLRow-multi-decode.swift +++ b/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift @@ -1,4 +1,4 @@ -/// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-psqlrow-multi-decode.sh +/// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrow-multi-decode.sh extension PostgresRow { func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0) { diff --git a/dev/generate-psqlrow-multi-decode.sh b/dev/generate-postgresrow-multi-decode.sh similarity index 95% rename from dev/generate-psqlrow-multi-decode.sh rename to dev/generate-postgresrow-multi-decode.sh index 84652339..b99be562 100755 --- a/dev/generate-psqlrow-multi-decode.sh +++ b/dev/generate-postgresrow-multi-decode.sh @@ -88,7 +88,7 @@ grep -q "ByteBuffer" "${BASH_SOURCE[0]}" || { { cat <<"EOF" -/// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-psqlrow-multi-decode.sh +/// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrow-multi-decode.sh EOF echo @@ -101,4 +101,4 @@ for n in {1..15}; do gen "$n" done echo "}" -} > "$here/../Sources/PostgresNIO/New/PSQLRow-multi-decode.swift" +} > "$here/../Sources/PostgresNIO/New/PostgresRow-multi-decode.swift" From 5dade1c2410bdc6995d7c182f0b33aa349c376bc Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 2 Mar 2022 07:35:35 +0100 Subject: [PATCH 072/292] Explicit TLS config. (#237) --- .../Connection/PostgresConnection.swift | 34 +++++++++----- .../ConnectionStateMachine.swift | 39 +++++++++++---- .../PostgresNIO/New/PSQLChannelHandler.swift | 31 +++++++++++- Sources/PostgresNIO/New/PSQLConnection.swift | 47 +++++++++++++++---- .../PSQLIntegrationTests.swift | 5 +- .../AuthenticationStateMachineTests.swift | 26 ++++++---- .../ConnectionStateMachineTests.swift | 19 +++++--- .../ConnectionAction+TestUtils.swift | 3 +- .../New/PSQLChannelHandlerTests.swift | 8 ++-- .../New/PSQLConnectionTests.swift | 3 +- 10 files changed, 161 insertions(+), 54 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 5377f110..d6a26e5e 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -54,17 +54,29 @@ extension PostgresConnection { logger: Logger = .init(label: "codes.vapor.postgres"), on eventLoop: EventLoop ) -> EventLoopFuture { - let configuration = PSQLConnection.Configuration( - connection: .resolved(address: socketAddress, serverName: serverHostname), - authentication: nil, - tlsConfiguration: tlsConfiguration - ) - - return PSQLConnection.connect( - configuration: configuration, - logger: logger, - on: eventLoop - ).map { connection in + var tlsFuture: EventLoopFuture + + if let tlsConfiguration = tlsConfiguration { + tlsFuture = eventLoop.makeSucceededVoidFuture().flatMapBlocking(onto: .global(qos: .default)) { + try PSQLConnection.Configuration.TLS.require(.init(configuration: tlsConfiguration)) + } + } else { + tlsFuture = eventLoop.makeSucceededFuture(.disable) + } + + return tlsFuture.flatMap { tls in + let configuration = PSQLConnection.Configuration( + connection: .resolved(address: socketAddress, serverName: serverHostname), + authentication: nil, + tls: tls + ) + + return PSQLConnection.connect( + configuration: configuration, + logger: logger, + on: eventLoop + ) + }.map { connection in PostgresConnection(underlying: connection, logger: logger) }.flatMapErrorThrowing { error in throw error.asAppropriatePostgresError diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 36bcdf39..82db845f 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -18,8 +18,13 @@ struct ConnectionStateMachine { } enum State { + enum TLSConfiguration { + case prefer + case require + } + case initialized - case sslRequestSent + case sslRequestSent(TLSConfiguration) case sslNegotiated case sslHandlerAdded case waitingToStartAuthentication @@ -114,26 +119,38 @@ struct ConnectionStateMachine { init() { self.state = .initialized } - + #if DEBUG /// for testing purposes only init(_ state: State) { self.state = state } #endif + + enum TLSConfiguration { + case disable + case prefer + case require + } - mutating func connected(requireTLS: Bool) -> ConnectionAction { + mutating func connected(tls: TLSConfiguration) -> ConnectionAction { guard case .initialized = self.state else { preconditionFailure("Unexpected state") } - if requireTLS { - self.state = .sslRequestSent + switch tls { + case .disable: + self.state = .waitingToStartAuthentication + return .provideAuthenticationContext + + case .prefer: + self.state = .sslRequestSent(.prefer) + return .sendSSLRequest + + case .require: + self.state = .sslRequestSent(.require) return .sendSSLRequest } - - self.state = .waitingToStartAuthentication - return .provideAuthenticationContext } mutating func provideAuthenticationContext(_ authContext: AuthContext) -> ConnectionAction { @@ -223,8 +240,12 @@ struct ConnectionStateMachine { mutating func sslUnsupportedReceived() -> ConnectionAction { switch self.state { - case .sslRequestSent: + case .sslRequestSent(.require): return self.closeConnectionAndCleanup(.sslUnsupported) + + case .sslRequestSent(.prefer): + self.state = .waitingToStartAuthentication + return .provideAuthenticationContext case .initialized, .sslNegotiated, diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PSQLChannelHandler.swift index c24ee07d..0862c517 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PSQLChannelHandler.swift @@ -57,7 +57,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { self.decoder = NIOSingleStepByteToMessageProcessor(PSQLBackendMessageDecoder()) } #endif - + // MARK: Handler lifecycle func handlerAdded(context: ChannelHandlerContext) { @@ -331,7 +331,8 @@ final class PSQLChannelHandler: ChannelDuplexHandler { // MARK: - Private Methods - private func connected(context: ChannelHandlerContext) { - let action = self.state.connected(requireTLS: self.configureSSLCallback != nil) + + let action = self.state.connected(tls: .init(self.configuration.tls)) self.run(action, with: context) } @@ -572,3 +573,29 @@ private extension Insecure.MD5.Digest { return String(decoding: result, as: Unicode.UTF8.self) } } + +extension ConnectionStateMachine.TLSConfiguration { + fileprivate init(_ connection: PSQLConnection.Configuration.TLS) { + switch connection.base { + case .disable: + self = .disable + case .require: + self = .require + case .prefer: + self = .prefer + } + } +} + +extension PSQLChannelHandler { + convenience init( + configuration: PSQLConnection.Configuration, + configureSSLCallback: ((Channel) throws -> Void)?) + { + self.init( + configuration: configuration, + logger: .psqlNoOpLogger, + configureSSLCallback: configureSSLCallback + ) + } +} diff --git a/Sources/PostgresNIO/New/PSQLConnection.swift b/Sources/PostgresNIO/New/PSQLConnection.swift index 2ebb2bba..0b1ce1ab 100644 --- a/Sources/PostgresNIO/New/PSQLConnection.swift +++ b/Sources/PostgresNIO/New/PSQLConnection.swift @@ -23,6 +23,30 @@ final class PSQLConnection { self.password = password } } + + struct TLS { + enum Base { + case disable + case prefer(NIOSSLContext) + case require(NIOSSLContext) + } + + var base: Base + + private init(_ base: Base) { + self.base = base + } + + static var disable: Self = Self.init(.disable) + + static func prefer(_ sslContext: NIOSSLContext) -> Self { + self.init(.prefer(sslContext)) + } + + static func require(_ sslContext: NIOSSLContext) -> Self { + self.init(.require(sslContext)) + } + } enum Connection { case unresolved(host: String, port: Int) @@ -34,27 +58,27 @@ final class PSQLConnection { /// The authentication properties to send to the Postgres server during startup auth handshake var authentication: Authentication? - var tlsConfiguration: TLSConfiguration? + var tls: TLS init(host: String, port: Int = 5432, username: String, database: String? = nil, password: String? = nil, - tlsConfiguration: TLSConfiguration? = nil + tls: TLS = .disable ) { self.connection = .unresolved(host: host, port: port) self.authentication = Authentication(username: username, password: password, database: database) - self.tlsConfiguration = tlsConfiguration + self.tls = tls } init(connection: Connection, authentication: Authentication?, - tlsConfiguration: TLSConfiguration? + tls: TLS ) { self.connection = connection self.authentication = authentication - self.tlsConfiguration = tlsConfiguration + self.tls = tls } } @@ -185,14 +209,19 @@ final class PSQLConnection { let bootstrap = ClientBootstrap(group: eventLoop) .channelInitializer { channel in var configureSSLCallback: ((Channel) throws -> ())? = nil - if let tlsConfiguration = configuration.tlsConfiguration { + + switch configuration.tls.base { + case .disable: + break + + case .prefer(let sslContext), .require(let sslContext): configureSSLCallback = { channel in channel.eventLoop.assertInEventLoop() - - let sslContext = try NIOSSLContext(configuration: tlsConfiguration) + let sslHandler = try NIOSSLClientHandler( context: sslContext, - serverHostname: configuration.sslServerHostname) + serverHostname: configuration.sslServerHostname + ) try channel.pipeline.syncOperations.addHandler(sslHandler, position: .first) } } diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 16d720f7..6dce981c 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -28,7 +28,7 @@ final class IntegrationTests: XCTestCase { username: env("POSTGRES_USER") ?? "test_username", database: env("POSTGRES_DB") ?? "test_database", password: "wrong_password", - tlsConfiguration: nil) + tls: .disable) let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } @@ -358,7 +358,8 @@ extension PSQLConnection { username: env("POSTGRES_USER") ?? "test_username", database: env("POSTGRES_DB") ?? "test_database", password: env("POSTGRES_PASSWORD") ?? "test_password", - tlsConfiguration: nil) + tls: .disable + ) return PSQLConnection.connect(configuration: config, logger: logger, on: eventLoop) } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift index b503f1ad..2ed28c20 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift @@ -6,7 +6,9 @@ class AuthenticationStateMachineTests: XCTestCase { func testAuthenticatePlaintext() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") - var state = ConnectionStateMachine(.waitingToStartAuthentication) + + var state = ConnectionStateMachine() + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(.plaintext), .sendPasswordMessage(.cleartext, authContext)) @@ -15,7 +17,8 @@ class AuthenticationStateMachineTests: XCTestCase { func testAuthenticateMD5() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") - var state = ConnectionStateMachine(.waitingToStartAuthentication) + var state = ConnectionStateMachine() + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) @@ -25,7 +28,8 @@ class AuthenticationStateMachineTests: XCTestCase { func testAuthenticateMD5WithoutPassword() { let authContext = AuthContext(username: "test", password: nil, database: "test") - var state = ConnectionStateMachine(.waitingToStartAuthentication) + var state = ConnectionStateMachine() + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) @@ -35,15 +39,16 @@ class AuthenticationStateMachineTests: XCTestCase { func testAuthenticateOkAfterStartUpWithoutAuthChallenge() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") - var state = ConnectionStateMachine(.waitingToStartAuthentication) - + var state = ConnectionStateMachine() + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait) } func testAuthenticationFailure() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") - var state = ConnectionStateMachine(.waitingToStartAuthentication) + var state = ConnectionStateMachine() + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) @@ -74,7 +79,8 @@ class AuthenticationStateMachineTests: XCTestCase { for (message, mechanism) in unsupported { let authContext = AuthContext(username: "test", password: "abc123", database: "test") - var state = ConnectionStateMachine(.waitingToStartAuthentication) + var state = ConnectionStateMachine() + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(message), .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unsupportedAuthMechanism(mechanism), closePromise: nil))) @@ -92,7 +98,8 @@ class AuthenticationStateMachineTests: XCTestCase { for message in unexpected { let authContext = AuthContext(username: "test", password: "abc123", database: "test") - var state = ConnectionStateMachine(.waitingToStartAuthentication) + var state = ConnectionStateMachine() + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(message), .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unexpectedBackendMessage(.authentication(message)), closePromise: nil))) @@ -118,7 +125,8 @@ class AuthenticationStateMachineTests: XCTestCase { for message in unexpected { let authContext = AuthContext(username: "test", password: "abc123", database: "test") - var state = ConnectionStateMachine(.waitingToStartAuthentication) + var state = ConnectionStateMachine() + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) XCTAssertEqual(state.authenticationMessageReceived(message), diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index 63d40e1a..5b7ed388 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -9,7 +9,7 @@ class ConnectionStateMachineTests: XCTestCase { func testStartup() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine() - XCTAssertEqual(state.connected(requireTLS: false), .provideAuthenticationContext) + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(.plaintext), .sendPasswordMessage(.cleartext, authContext)) XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait) @@ -18,7 +18,7 @@ class ConnectionStateMachineTests: XCTestCase { func testSSLStartupSuccess() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine() - XCTAssertEqual(state.connected(requireTLS: true), .sendSSLRequest) + XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest) XCTAssertEqual(state.sslSupportedReceived(), .establishSSLConnection) XCTAssertEqual(state.sslHandlerAdded(), .wait) XCTAssertEqual(state.sslEstablished(), .provideAuthenticationContext) @@ -31,19 +31,26 @@ class ConnectionStateMachineTests: XCTestCase { struct SSLHandlerAddError: Error, Equatable {} var state = ConnectionStateMachine() - XCTAssertEqual(state.connected(requireTLS: true), .sendSSLRequest) + XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest) XCTAssertEqual(state.sslSupportedReceived(), .establishSSLConnection) let failError = PSQLError.failedToAddSSLHandler(underlying: SSLHandlerAddError()) XCTAssertEqual(state.errorHappened(failError), .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: failError, closePromise: nil))) } - func testSSLStartupSSLUnsupported() { + func testTLSRequiredStartupSSLUnsupported() { var state = ConnectionStateMachine() - XCTAssertEqual(state.connected(requireTLS: true), .sendSSLRequest) + XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest) XCTAssertEqual(state.sslUnsupportedReceived(), .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: PSQLError.sslUnsupported, closePromise: nil))) } + + func testTLSPreferredStartupSSLUnsupported() { + var state = ConnectionStateMachine() + + XCTAssertEqual(state.connected(tls: .prefer), .sendSSLRequest) + XCTAssertEqual(state.sslUnsupportedReceived(), .provideAuthenticationContext) + } func testParameterStatusReceivedAndBackendKeyAfterAuthenticated() { var state = ConnectionStateMachine(.authenticated(nil, [:])) @@ -133,7 +140,7 @@ class ConnectionStateMachineTests: XCTestCase { promise: queryPromise) XCTAssertEqual(state.enqueue(task: .extendedQuery(extendedQueryContext)), .wait) - XCTAssertEqual(state.connected(requireTLS: false), .provideAuthenticationContext) + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) let fields: [PSQLBackendMessage.Field: String] = [ diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift index 6db93101..13323e76 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -91,7 +91,8 @@ extension ConnectionStateMachine { processID: 2730, secretKey: 882037977, parameters: paramaters, - transactionState: transactionState) + transactionState: transactionState + ) } } diff --git a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift index f47a0071..6927c50e 100644 --- a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift @@ -38,7 +38,7 @@ class PSQLChannelHandlerTests: XCTestCase { func testEstablishSSLCallbackIsCalledIfSSLIsSupported() { var config = self.testConnectionConfiguration() - config.tlsConfiguration = .makeClientConfiguration() + XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) var addSSLCallbackIsHit = false let handler = PSQLChannelHandler(configuration: config) { channel in addSSLCallbackIsHit = true @@ -80,7 +80,7 @@ class PSQLChannelHandlerTests: XCTestCase { func testSSLUnsupportedClosesConnection() { var config = self.testConnectionConfiguration() - config.tlsConfiguration = .makeClientConfiguration() + XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) let handler = PSQLChannelHandler(configuration: config) { channel in XCTFail("This callback should never be exectuded") @@ -173,7 +173,7 @@ class PSQLChannelHandlerTests: XCTestCase { username: String = "test", database: String = "postgres", password: String = "password", - tlsConfiguration: TLSConfiguration? = nil + tls: PSQLConnection.Configuration.TLS = .disable ) -> PSQLConnection.Configuration { PSQLConnection.Configuration( host: host, @@ -181,7 +181,7 @@ class PSQLChannelHandlerTests: XCTestCase { username: username, database: database, password: password, - tlsConfiguration: tlsConfiguration + tls: tls ) } } diff --git a/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift b/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift index 708c6c0e..a0b68cea 100644 --- a/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift @@ -27,7 +27,8 @@ class PSQLConnectionTests: XCTestCase { username: "postgres", database: "postgres", password: "abc123", - tlsConfiguration: nil) + tls: .disable + ) var logger = Logger.psqlTest logger.logLevel = .trace From 13f362f3485258b0dc1faca0967500a88c07d5fe Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 2 Mar 2022 08:01:01 +0100 Subject: [PATCH 073/292] Rename PSQLFrontendMessage to PostgresFrontendMessage (#239) --- .../New/BufferedMessageEncoder.swift | 2 +- .../New/Extensions/ByteBuffer+PSQL.swift | 2 +- Sources/PostgresNIO/New/Messages/Bind.swift | 2 +- Sources/PostgresNIO/New/Messages/Cancel.swift | 2 +- Sources/PostgresNIO/New/Messages/Close.swift | 2 +- .../PostgresNIO/New/Messages/Describe.swift | 2 +- .../PostgresNIO/New/Messages/Execute.swift | 2 +- Sources/PostgresNIO/New/Messages/Parse.swift | 2 +- .../PostgresNIO/New/Messages/Password.swift | 2 +- .../New/Messages/SASLInitialResponse.swift | 2 +- .../New/Messages/SASLResponse.swift | 2 +- .../PostgresNIO/New/Messages/SSLRequest.swift | 2 +- .../PostgresNIO/New/Messages/Startup.swift | 2 +- .../PostgresNIO/New/PSQLChannelHandler.swift | 12 ++++---- .../New/PSQLFrontendMessageEncoder.swift | 6 ++-- ...ge.swift => PostgresFrontendMessage.swift} | 4 +-- .../PSQLFrontendMessageDecoder.swift | 18 ++++++------ .../New/Messages/BindTests.swift | 6 ++-- .../New/Messages/CancelTests.swift | 4 +-- .../New/Messages/CloseTests.swift | 8 +++--- .../New/Messages/DescribeTests.swift | 8 +++--- .../New/Messages/ExecuteTests.swift | 4 +-- .../New/Messages/ParseTests.swift | 6 ++-- .../New/Messages/PasswordTests.swift | 4 +-- .../Messages/SASLInitialResponseTests.swift | 12 ++++---- .../New/Messages/SASLResponseTests.swift | 12 ++++---- .../New/Messages/SSLRequestTests.swift | 4 +-- .../New/Messages/StartupTests.swift | 10 +++---- .../New/PSQLChannelHandlerTests.swift | 26 ++++++++--------- .../New/PSQLFrontendMessageTests.swift | 28 +++++++++---------- 30 files changed, 99 insertions(+), 99 deletions(-) rename Sources/PostgresNIO/New/{PSQLFrontendMessage.swift => PostgresFrontendMessage.swift} (97%) diff --git a/Sources/PostgresNIO/New/BufferedMessageEncoder.swift b/Sources/PostgresNIO/New/BufferedMessageEncoder.swift index 0942b972..f202fcff 100644 --- a/Sources/PostgresNIO/New/BufferedMessageEncoder.swift +++ b/Sources/PostgresNIO/New/BufferedMessageEncoder.swift @@ -15,7 +15,7 @@ struct BufferedMessageEncoder { self.encoder = encoder } - mutating func encode(_ message: PSQLFrontendMessage) { + mutating func encode(_ message: PostgresFrontendMessage) { switch self.state { case .flushed: self.state = .writable diff --git a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift index a948b41b..6793d50e 100644 --- a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift +++ b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift @@ -6,7 +6,7 @@ internal extension ByteBuffer { self.writeInteger(messageID.rawValue) } - mutating func psqlWriteFrontendMessageID(_ messageID: PSQLFrontendMessage.ID) { + mutating func psqlWriteFrontendMessageID(_ messageID: PostgresFrontendMessage.ID) { self.writeInteger(messageID.rawValue) } diff --git a/Sources/PostgresNIO/New/Messages/Bind.swift b/Sources/PostgresNIO/New/Messages/Bind.swift index 74868b4c..9fc0445e 100644 --- a/Sources/PostgresNIO/New/Messages/Bind.swift +++ b/Sources/PostgresNIO/New/Messages/Bind.swift @@ -1,6 +1,6 @@ import NIOCore -extension PSQLFrontendMessage { +extension PostgresFrontendMessage { struct Bind: PSQLMessagePayloadEncodable, Equatable { /// The name of the destination portal (an empty string selects the unnamed portal). diff --git a/Sources/PostgresNIO/New/Messages/Cancel.swift b/Sources/PostgresNIO/New/Messages/Cancel.swift index 64107d7a..2f29d239 100644 --- a/Sources/PostgresNIO/New/Messages/Cancel.swift +++ b/Sources/PostgresNIO/New/Messages/Cancel.swift @@ -1,6 +1,6 @@ import NIOCore -extension PSQLFrontendMessage { +extension PostgresFrontendMessage { struct Cancel: PSQLMessagePayloadEncodable, Equatable { /// The cancel request code. The value is chosen to contain 1234 in the most significant 16 bits, diff --git a/Sources/PostgresNIO/New/Messages/Close.swift b/Sources/PostgresNIO/New/Messages/Close.swift index 5ed532e6..7f038f94 100644 --- a/Sources/PostgresNIO/New/Messages/Close.swift +++ b/Sources/PostgresNIO/New/Messages/Close.swift @@ -1,6 +1,6 @@ import NIOCore -extension PSQLFrontendMessage { +extension PostgresFrontendMessage { enum Close: PSQLMessagePayloadEncodable, Equatable { case preparedStatement(String) diff --git a/Sources/PostgresNIO/New/Messages/Describe.swift b/Sources/PostgresNIO/New/Messages/Describe.swift index 0a3105cc..76167d32 100644 --- a/Sources/PostgresNIO/New/Messages/Describe.swift +++ b/Sources/PostgresNIO/New/Messages/Describe.swift @@ -1,6 +1,6 @@ import NIOCore -extension PSQLFrontendMessage { +extension PostgresFrontendMessage { enum Describe: PSQLMessagePayloadEncodable, Equatable { diff --git a/Sources/PostgresNIO/New/Messages/Execute.swift b/Sources/PostgresNIO/New/Messages/Execute.swift index 891bd9aa..17646484 100644 --- a/Sources/PostgresNIO/New/Messages/Execute.swift +++ b/Sources/PostgresNIO/New/Messages/Execute.swift @@ -1,6 +1,6 @@ import NIOCore -extension PSQLFrontendMessage { +extension PostgresFrontendMessage { struct Execute: PSQLMessagePayloadEncodable, Equatable { /// The name of the portal to execute (an empty string selects the unnamed portal). diff --git a/Sources/PostgresNIO/New/Messages/Parse.swift b/Sources/PostgresNIO/New/Messages/Parse.swift index fa20c7bd..268ad4ff 100644 --- a/Sources/PostgresNIO/New/Messages/Parse.swift +++ b/Sources/PostgresNIO/New/Messages/Parse.swift @@ -1,6 +1,6 @@ import NIOCore -extension PSQLFrontendMessage { +extension PostgresFrontendMessage { struct Parse: PSQLMessagePayloadEncodable, Equatable { /// The name of the destination prepared statement (an empty string selects the unnamed prepared statement). diff --git a/Sources/PostgresNIO/New/Messages/Password.swift b/Sources/PostgresNIO/New/Messages/Password.swift index 88e885f9..81d7ab30 100644 --- a/Sources/PostgresNIO/New/Messages/Password.swift +++ b/Sources/PostgresNIO/New/Messages/Password.swift @@ -1,6 +1,6 @@ import NIOCore -extension PSQLFrontendMessage { +extension PostgresFrontendMessage { struct Password: PSQLMessagePayloadEncodable, Equatable { let value: String diff --git a/Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift b/Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift index ead609c7..73db9332 100644 --- a/Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift +++ b/Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift @@ -1,6 +1,6 @@ import NIOCore -extension PSQLFrontendMessage { +extension PostgresFrontendMessage { struct SASLInitialResponse: PSQLMessagePayloadEncodable, Equatable { diff --git a/Sources/PostgresNIO/New/Messages/SASLResponse.swift b/Sources/PostgresNIO/New/Messages/SASLResponse.swift index dc49a506..a6709dcd 100644 --- a/Sources/PostgresNIO/New/Messages/SASLResponse.swift +++ b/Sources/PostgresNIO/New/Messages/SASLResponse.swift @@ -1,6 +1,6 @@ import NIOCore -extension PSQLFrontendMessage { +extension PostgresFrontendMessage { struct SASLResponse: PSQLMessagePayloadEncodable, Equatable { diff --git a/Sources/PostgresNIO/New/Messages/SSLRequest.swift b/Sources/PostgresNIO/New/Messages/SSLRequest.swift index f67f25fe..6f9c45a3 100644 --- a/Sources/PostgresNIO/New/Messages/SSLRequest.swift +++ b/Sources/PostgresNIO/New/Messages/SSLRequest.swift @@ -1,6 +1,6 @@ import NIOCore -extension PSQLFrontendMessage { +extension PostgresFrontendMessage { /// A message asking the PostgreSQL server if TLS is supported /// For more info, see https://www.postgresql.org/docs/10/static/protocol-flow.html#id-1.10.5.7.11 struct SSLRequest: PSQLMessagePayloadEncodable, Equatable { diff --git a/Sources/PostgresNIO/New/Messages/Startup.swift b/Sources/PostgresNIO/New/Messages/Startup.swift index 6e991928..f7da2127 100644 --- a/Sources/PostgresNIO/New/Messages/Startup.swift +++ b/Sources/PostgresNIO/New/Messages/Startup.swift @@ -1,6 +1,6 @@ import NIOCore -extension PSQLFrontendMessage { +extension PostgresFrontendMessage { struct Startup: PSQLMessagePayloadEncodable, Equatable { /// Creates a `Startup` with "3.0" as the protocol version. diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PSQLChannelHandler.swift index 0862c517..f0671bca 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PSQLChannelHandler.swift @@ -398,7 +398,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { context: ChannelHandlerContext) { precondition(self.rowStream == nil, "Expected to not have an open stream at this point") - let parse = PSQLFrontendMessage.Parse( + let parse = PostgresFrontendMessage.Parse( preparedStatementName: statementName, query: query, parameters: []) @@ -413,7 +413,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { executeStatement: PSQLExecuteStatement, context: ChannelHandlerContext ) { - let bind = PSQLFrontendMessage.Bind( + let bind = PostgresFrontendMessage.Bind( portalName: "", preparedStatementName: executeStatement.name, bind: executeStatement.binds) @@ -430,11 +430,11 @@ final class PSQLChannelHandler: ChannelDuplexHandler { { precondition(self.rowStream == nil, "Expected to not have an open stream at this point") let unnamedStatementName = "" - let parse = PSQLFrontendMessage.Parse( + let parse = PostgresFrontendMessage.Parse( preparedStatementName: unnamedStatementName, query: query.sql, parameters: query.binds.metadata.map(\.dataType)) - let bind = PSQLFrontendMessage.Bind( + let bind = PostgresFrontendMessage.Bind( portalName: "", preparedStatementName: unnamedStatementName, bind: query.binds) @@ -528,8 +528,8 @@ extension PSQLConnection.Configuration.Authentication { } extension AuthContext { - func toStartupParameters() -> PSQLFrontendMessage.Startup.Parameters { - PSQLFrontendMessage.Startup.Parameters( + func toStartupParameters() -> PostgresFrontendMessage.Startup.Parameters { + PostgresFrontendMessage.Startup.Parameters( user: self.username, database: self.database, options: nil, diff --git a/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift index 92ffeb07..8447c683 100644 --- a/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift @@ -1,10 +1,10 @@ struct PSQLFrontendMessageEncoder: MessageToByteEncoder { - typealias OutboundIn = PSQLFrontendMessage + typealias OutboundIn = PostgresFrontendMessage init() {} - func encode(data message: PSQLFrontendMessage, out buffer: inout ByteBuffer) { + func encode(data message: PostgresFrontendMessage, out buffer: inout ByteBuffer) { switch message { case .bind(let bind): buffer.writeInteger(message.id.rawValue) @@ -63,7 +63,7 @@ struct PSQLFrontendMessageEncoder: MessageToByteEncoder { } private func encode( - messageID: PSQLFrontendMessage.ID, + messageID: PostgresFrontendMessage.ID, payload: Payload, into buffer: inout ByteBuffer) { diff --git a/Sources/PostgresNIO/New/PSQLFrontendMessage.swift b/Sources/PostgresNIO/New/PostgresFrontendMessage.swift similarity index 97% rename from Sources/PostgresNIO/New/PSQLFrontendMessage.swift rename to Sources/PostgresNIO/New/PostgresFrontendMessage.swift index 1a3cb28d..2017cd1a 100644 --- a/Sources/PostgresNIO/New/PSQLFrontendMessage.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessage.swift @@ -4,7 +4,7 @@ import NIOCore /// /// All messages are defined in the official Postgres Documentation in the section /// [Frontend/Backend Protocol – Message Formats](https://www.postgresql.org/docs/13/protocol-message-formats.html) -enum PSQLFrontendMessage: Equatable { +enum PostgresFrontendMessage: Equatable { case bind(Bind) case cancel(Cancel) case close(Close) @@ -92,7 +92,7 @@ enum PSQLFrontendMessage: Equatable { } } -extension PSQLFrontendMessage { +extension PostgresFrontendMessage { var id: ID { switch self { diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index c639f4b2..047a2968 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -1,7 +1,7 @@ @testable import PostgresNIO struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { - typealias InboundOut = PSQLFrontendMessage + typealias InboundOut = PostgresFrontendMessage private(set) var isInStartup: Bool @@ -9,7 +9,7 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { self.isInStartup = true } - mutating func decode(buffer: inout ByteBuffer) throws -> PSQLFrontendMessage? { + mutating func decode(buffer: inout ByteBuffer) throws -> PostgresFrontendMessage? { // make sure we have at least one byte to read guard buffer.readableBytes > 0 else { return nil @@ -58,14 +58,14 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { } } - let parameters = PSQLFrontendMessage.Startup.Parameters( + let parameters = PostgresFrontendMessage.Startup.Parameters( user: user!, database: database, options: options, replication: .false ) - let startup = PSQLFrontendMessage.Startup( + let startup = PostgresFrontendMessage.Startup( protocolVersion: 0x00_03_00_00, parameters: parameters ) @@ -95,7 +95,7 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { } // 2. make sure we have a known message identifier - guard let messageID = PSQLFrontendMessage.ID(rawValue: idByte) else { + guard let messageID = PostgresFrontendMessage.ID(rawValue: idByte) else { throw PSQLDecodingError.unknownMessageIDReceived(messageID: idByte, messageBytes: completeMessageBuffer) } @@ -106,7 +106,7 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { // move reader index forward by five bytes slice.moveReaderIndex(forwardBy: 5) - return try PSQLFrontendMessage.decode(from: &slice, for: messageID) + return try PostgresFrontendMessage.decode(from: &slice, for: messageID) } catch let error as PSQLPartialDecodingError { throw PSQLDecodingError.withPartialError(error, messageID: messageID.rawValue, messageBytes: completeMessageBuffer) } catch { @@ -114,14 +114,14 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { } } - mutating func decodeLast(buffer: inout ByteBuffer, seenEOF: Bool) throws -> PSQLFrontendMessage? { + mutating func decodeLast(buffer: inout ByteBuffer, seenEOF: Bool) throws -> PostgresFrontendMessage? { try self.decode(buffer: &buffer) } } -extension PSQLFrontendMessage { +extension PostgresFrontendMessage { - static func decode(from buffer: inout ByteBuffer, for messageID: ID) throws -> PSQLFrontendMessage { + static func decode(from buffer: inout ByteBuffer, for messageID: ID) throws -> PostgresFrontendMessage { switch messageID { case .bind: preconditionFailure("TODO: Unimplemented") diff --git a/Tests/PostgresNIOTests/New/Messages/BindTests.swift b/Tests/PostgresNIOTests/New/Messages/BindTests.swift index 285d00ca..5d63277d 100644 --- a/Tests/PostgresNIOTests/New/Messages/BindTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BindTests.swift @@ -10,12 +10,12 @@ class BindTests: XCTestCase { XCTAssertNoThrow(try bindings.append("Hello", context: .default)) XCTAssertNoThrow(try bindings.append("World", context: .default)) var byteBuffer = ByteBuffer() - let bind = PSQLFrontendMessage.Bind(portalName: "", preparedStatementName: "", bind: bindings) - let message = PSQLFrontendMessage.bind(bind) + let bind = PostgresFrontendMessage.Bind(portalName: "", preparedStatementName: "", bind: bindings) + let message = PostgresFrontendMessage.bind(bind) encoder.encode(data: message, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 37) - XCTAssertEqual(PSQLFrontendMessage.ID.bind.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(PostgresFrontendMessage.ID.bind.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 36) XCTAssertEqual("", byteBuffer.readNullTerminatedString()) XCTAssertEqual("", byteBuffer.readNullTerminatedString()) diff --git a/Tests/PostgresNIOTests/New/Messages/CancelTests.swift b/Tests/PostgresNIOTests/New/Messages/CancelTests.swift index a1626538..c42f1999 100644 --- a/Tests/PostgresNIOTests/New/Messages/CancelTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CancelTests.swift @@ -7,8 +7,8 @@ class CancelTests: XCTestCase { func testEncodeCancel() { let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() - let cancel = PSQLFrontendMessage.Cancel(processID: 1234, secretKey: 4567) - let message = PSQLFrontendMessage.cancel(cancel) + let cancel = PostgresFrontendMessage.Cancel(processID: 1234, secretKey: 4567) + let message = PostgresFrontendMessage.cancel(cancel) encoder.encode(data: message, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 16) diff --git a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift index d9edf95b..f6a0237b 100644 --- a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift @@ -7,11 +7,11 @@ class CloseTests: XCTestCase { func testEncodeClosePortal() { let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() - let message = PSQLFrontendMessage.close(.portal("Hello")) + let message = PostgresFrontendMessage.close(.portal("Hello")) encoder.encode(data: message, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 12) - XCTAssertEqual(PSQLFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(PostgresFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(11, byteBuffer.readInteger(as: Int32.self)) XCTAssertEqual(UInt8(ascii: "P"), byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual("Hello", byteBuffer.readNullTerminatedString()) @@ -21,11 +21,11 @@ class CloseTests: XCTestCase { func testEncodeCloseUnnamedStatement() { let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() - let message = PSQLFrontendMessage.close(.preparedStatement("")) + let message = PostgresFrontendMessage.close(.preparedStatement("")) encoder.encode(data: message, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 7) - XCTAssertEqual(PSQLFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(PostgresFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(6, byteBuffer.readInteger(as: Int32.self)) XCTAssertEqual(UInt8(ascii: "S"), byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual("", byteBuffer.readNullTerminatedString()) diff --git a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift index 752a3d0f..df26f3d7 100644 --- a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift @@ -7,11 +7,11 @@ class DescribeTests: XCTestCase { func testEncodeDescribePortal() { let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() - let message = PSQLFrontendMessage.describe(.portal("Hello")) + let message = PostgresFrontendMessage.describe(.portal("Hello")) encoder.encode(data: message, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 12) - XCTAssertEqual(PSQLFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(PostgresFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(11, byteBuffer.readInteger(as: Int32.self)) XCTAssertEqual(UInt8(ascii: "P"), byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual("Hello", byteBuffer.readNullTerminatedString()) @@ -21,11 +21,11 @@ class DescribeTests: XCTestCase { func testEncodeDescribeUnnamedStatement() { let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() - let message = PSQLFrontendMessage.describe(.preparedStatement("")) + let message = PostgresFrontendMessage.describe(.preparedStatement("")) encoder.encode(data: message, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 7) - XCTAssertEqual(PSQLFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(PostgresFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(6, byteBuffer.readInteger(as: Int32.self)) XCTAssertEqual(UInt8(ascii: "S"), byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual("", byteBuffer.readNullTerminatedString()) diff --git a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift index 9fdf06a7..dc5e2767 100644 --- a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift @@ -7,11 +7,11 @@ class ExecuteTests: XCTestCase { func testEncodeExecute() { let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() - let message = PSQLFrontendMessage.execute(.init(portalName: "", maxNumberOfRows: 0)) + let message = PostgresFrontendMessage.execute(.init(portalName: "", maxNumberOfRows: 0)) encoder.encode(data: message, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 10) // 1 (id) + 4 (length) + 1 (empty null terminated string) + 4 (count) - XCTAssertEqual(PSQLFrontendMessage.ID.execute.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(PostgresFrontendMessage.ID.execute.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(9, byteBuffer.readInteger(as: Int32.self)) // length XCTAssertEqual("", byteBuffer.readNullTerminatedString()) XCTAssertEqual(0, byteBuffer.readInteger(as: Int32.self)) diff --git a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift index 64654153..3d562473 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift @@ -7,11 +7,11 @@ class ParseTests: XCTestCase { func testEncode() { let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() - let parse = PSQLFrontendMessage.Parse( + let parse = PostgresFrontendMessage.Parse( preparedStatementName: "test", query: "SELECT version()", parameters: [.bool, .int8, .bytea, .varchar, .text, .uuid, .json, .jsonbArray]) - let message = PSQLFrontendMessage.parse(parse) + let message = PostgresFrontendMessage.parse(parse) encoder.encode(data: message, out: &byteBuffer) let length: Int = 1 + 4 + (parse.preparedStatementName.count + 1) + (parse.query.count + 1) + 2 + parse.parameters.count * 4 @@ -22,7 +22,7 @@ class ParseTests: XCTestCase { // + 1 query () XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.parse.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.parse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) XCTAssertEqual(byteBuffer.readNullTerminatedString(), parse.preparedStatementName) XCTAssertEqual(byteBuffer.readNullTerminatedString(), parse.query) diff --git a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift index 492d2723..7572d382 100644 --- a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift @@ -8,13 +8,13 @@ class PasswordTests: XCTestCase { let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() // md522d085ed8dc3377968dc1c1a40519a2a = "abc123" with salt 1, 2, 3, 4 - let message = PSQLFrontendMessage.password(.init(value: "md522d085ed8dc3377968dc1c1a40519a2a")) + let message = PostgresFrontendMessage.password(.init(value: "md522d085ed8dc3377968dc1c1a40519a2a")) encoder.encode(data: message, out: &byteBuffer) let expectedLength = 41 // 1 (id) + 4 (length) + 35 (string) + 1 (null termination) XCTAssertEqual(byteBuffer.readableBytes, expectedLength) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.password.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.password.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(expectedLength - 1)) // length XCTAssertEqual(byteBuffer.readNullTerminatedString(), "md522d085ed8dc3377968dc1c1a40519a2a") } diff --git a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift index 8ad83134..08b3097d 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift @@ -7,9 +7,9 @@ class SASLInitialResponseTests: XCTestCase { func testEncodeWithData() { let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() - let sasl = PSQLFrontendMessage.SASLInitialResponse( + let sasl = PostgresFrontendMessage.SASLInitialResponse( saslMechanism: "hello", initialData: [0, 1, 2, 3, 4, 5, 6, 7]) - let message = PSQLFrontendMessage.saslInitialResponse(sasl) + let message = PostgresFrontendMessage.saslInitialResponse(sasl) encoder.encode(data: message, out: &byteBuffer) let length: Int = 1 + 4 + (sasl.saslMechanism.count + 1) + 4 + sasl.initialData.count @@ -21,7 +21,7 @@ class SASLInitialResponseTests: XCTestCase { // + 8 initialData XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.saslInitialResponse.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslInitialResponse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) XCTAssertEqual(byteBuffer.readNullTerminatedString(), sasl.saslMechanism) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(sasl.initialData.count)) @@ -32,9 +32,9 @@ class SASLInitialResponseTests: XCTestCase { func testEncodeWithoutData() { let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() - let sasl = PSQLFrontendMessage.SASLInitialResponse( + let sasl = PostgresFrontendMessage.SASLInitialResponse( saslMechanism: "hello", initialData: []) - let message = PSQLFrontendMessage.saslInitialResponse(sasl) + let message = PostgresFrontendMessage.saslInitialResponse(sasl) encoder.encode(data: message, out: &byteBuffer) let length: Int = 1 + 4 + (sasl.saslMechanism.count + 1) + 4 + sasl.initialData.count @@ -46,7 +46,7 @@ class SASLInitialResponseTests: XCTestCase { // + 0 initialData XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.saslInitialResponse.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslInitialResponse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) XCTAssertEqual(byteBuffer.readNullTerminatedString(), sasl.saslMechanism) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(-1)) diff --git a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift index 2b528ff4..e148420f 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift @@ -7,14 +7,14 @@ class SASLResponseTests: XCTestCase { func testEncodeWithData() { let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() - let sasl = PSQLFrontendMessage.SASLResponse(data: [0, 1, 2, 3, 4, 5, 6, 7]) - let message = PSQLFrontendMessage.saslResponse(sasl) + let sasl = PostgresFrontendMessage.SASLResponse(data: [0, 1, 2, 3, 4, 5, 6, 7]) + let message = PostgresFrontendMessage.saslResponse(sasl) encoder.encode(data: message, out: &byteBuffer) let length: Int = 1 + 4 + (sasl.data.count) XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.saslResponse.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslResponse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) XCTAssertEqual(byteBuffer.readBytes(length: sasl.data.count), sasl.data) XCTAssertEqual(byteBuffer.readableBytes, 0) @@ -23,14 +23,14 @@ class SASLResponseTests: XCTestCase { func testEncodeWithoutData() { let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() - let sasl = PSQLFrontendMessage.SASLResponse(data: []) - let message = PSQLFrontendMessage.saslResponse(sasl) + let sasl = PostgresFrontendMessage.SASLResponse(data: []) + let message = PostgresFrontendMessage.saslResponse(sasl) encoder.encode(data: message, out: &byteBuffer) let length: Int = 1 + 4 XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PSQLFrontendMessage.ID.saslResponse.rawValue) + XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslResponse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) XCTAssertEqual(byteBuffer.readableBytes, 0) } diff --git a/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift b/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift index 1cc72bb1..9a973f2b 100644 --- a/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift @@ -7,8 +7,8 @@ class SSLRequestTests: XCTestCase { func testSSLRequest() { let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() - let request = PSQLFrontendMessage.SSLRequest() - let message = PSQLFrontendMessage.sslRequest(request) + let request = PostgresFrontendMessage.SSLRequest() + let message = PostgresFrontendMessage.sslRequest(request) encoder.encode(data: message, out: &byteBuffer) let byteBufferLength = Int32(byteBuffer.readableBytes) diff --git a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift index 913d02ef..08a9ee21 100644 --- a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift @@ -8,22 +8,22 @@ class StartupTests: XCTestCase { let encoder = PSQLFrontendMessageEncoder() var byteBuffer = ByteBuffer() - let replicationValues: [PSQLFrontendMessage.Startup.Parameters.Replication] = [ + let replicationValues: [PostgresFrontendMessage.Startup.Parameters.Replication] = [ .`true`, .`false`, .database ] for replication in replicationValues { - let parameters = PSQLFrontendMessage.Startup.Parameters( + let parameters = PostgresFrontendMessage.Startup.Parameters( user: "test", database: "abc123", options: "some options", replication: replication ) - let startup = PSQLFrontendMessage.Startup.versionThree(parameters: parameters) - let message = PSQLFrontendMessage.startup(startup) + let startup = PostgresFrontendMessage.Startup.versionThree(parameters: parameters) + let message = PostgresFrontendMessage.startup(startup) encoder.encode(data: message, out: &byteBuffer) let byteBufferLength = Int32(byteBuffer.readableBytes) @@ -46,7 +46,7 @@ class StartupTests: XCTestCase { } } -extension PSQLFrontendMessage.Startup.Parameters.Replication { +extension PostgresFrontendMessage.Startup.Parameters.Replication { var stringValue: String { switch self { case .true: diff --git a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift index 6927c50e..8085c326 100644 --- a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift @@ -19,9 +19,9 @@ class PSQLChannelHandlerTests: XCTestCase { ]) defer { XCTAssertNoThrow(try embedded.finish()) } - var maybeMessage: PSQLFrontendMessage? + var maybeMessage: PostgresFrontendMessage? XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil)) - XCTAssertNoThrow(maybeMessage = try embedded.readOutbound(as: PSQLFrontendMessage.self)) + XCTAssertNoThrow(maybeMessage = try embedded.readOutbound(as: PostgresFrontendMessage.self)) guard case .startup(let startup) = maybeMessage else { return XCTFail("Unexpected message") } @@ -49,9 +49,9 @@ class PSQLChannelHandlerTests: XCTestCase { handler ]) - var maybeMessage: PSQLFrontendMessage? + var maybeMessage: PostgresFrontendMessage? XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil)) - XCTAssertNoThrow(maybeMessage = try embedded.readOutbound(as: PSQLFrontendMessage.self)) + XCTAssertNoThrow(maybeMessage = try embedded.readOutbound(as: PostgresFrontendMessage.self)) guard case .sslRequest(let request) = maybeMessage else { return XCTFail("Unexpected message") } @@ -67,8 +67,8 @@ class PSQLChannelHandlerTests: XCTestCase { embedded.pipeline.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: "")) // startup message should be issued - var maybeStartupMessage: PSQLFrontendMessage? - XCTAssertNoThrow(maybeStartupMessage = try embedded.readOutbound(as: PSQLFrontendMessage.self)) + var maybeStartupMessage: PostgresFrontendMessage? + XCTAssertNoThrow(maybeStartupMessage = try embedded.readOutbound(as: PostgresFrontendMessage.self)) guard case .startup(let startupMessage) = maybeStartupMessage else { return XCTFail("Unexpected message") } @@ -98,7 +98,7 @@ class PSQLChannelHandlerTests: XCTestCase { XCTAssertTrue(embedded.isActive) // read the ssl request message - XCTAssertEqual(try embedded.readOutbound(as: PSQLFrontendMessage.self), .sslRequest(.init())) + XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .sslRequest(.init())) XCTAssertNoThrow(try embedded.writeInbound(PSQLBackendMessage.sslUnsupported)) // the event handler should have seen an error @@ -126,12 +126,12 @@ class PSQLChannelHandlerTests: XCTestCase { ]) embedded.triggerUserOutboundEvent(PSQLOutgoingEvent.authenticate(authContext), promise: nil) - XCTAssertEqual(try embedded.readOutbound(as: PSQLFrontendMessage.self), .startup(.versionThree(parameters: authContext.toStartupParameters()))) + XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .startup(.versionThree(parameters: authContext.toStartupParameters()))) XCTAssertNoThrow(try embedded.writeInbound(PSQLBackendMessage.authentication(.md5(salt: (0,1,2,3))))) - var message: PSQLFrontendMessage? - XCTAssertNoThrow(message = try embedded.readOutbound(as: PSQLFrontendMessage.self)) + var message: PostgresFrontendMessage? + XCTAssertNoThrow(message = try embedded.readOutbound(as: PostgresFrontendMessage.self)) XCTAssertEqual(message, .password(.init(value: "md522d085ed8dc3377968dc1c1a40519a2a"))) } @@ -155,12 +155,12 @@ class PSQLChannelHandlerTests: XCTestCase { ]) embedded.triggerUserOutboundEvent(PSQLOutgoingEvent.authenticate(authContext), promise: nil) - XCTAssertEqual(try embedded.readOutbound(as: PSQLFrontendMessage.self), .startup(.versionThree(parameters: authContext.toStartupParameters()))) + XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .startup(.versionThree(parameters: authContext.toStartupParameters()))) XCTAssertNoThrow(try embedded.writeInbound(PSQLBackendMessage.authentication(.plaintext))) - var message: PSQLFrontendMessage? - XCTAssertNoThrow(message = try embedded.readOutbound(as: PSQLFrontendMessage.self)) + var message: PostgresFrontendMessage? + XCTAssertNoThrow(message = try embedded.readOutbound(as: PostgresFrontendMessage.self)) XCTAssertEqual(message, .password(.init(value: password))) } diff --git a/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift index 7a8d56eb..59b69bae 100644 --- a/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift @@ -7,17 +7,17 @@ class PSQLFrontendMessageTests: XCTestCase { // MARK: ID func testMessageIDs() { - XCTAssertEqual(PSQLFrontendMessage.ID.bind.rawValue, UInt8(ascii: "B")) - XCTAssertEqual(PSQLFrontendMessage.ID.close.rawValue, UInt8(ascii: "C")) - XCTAssertEqual(PSQLFrontendMessage.ID.describe.rawValue, UInt8(ascii: "D")) - XCTAssertEqual(PSQLFrontendMessage.ID.execute.rawValue, UInt8(ascii: "E")) - XCTAssertEqual(PSQLFrontendMessage.ID.flush.rawValue, UInt8(ascii: "H")) - XCTAssertEqual(PSQLFrontendMessage.ID.parse.rawValue, UInt8(ascii: "P")) - XCTAssertEqual(PSQLFrontendMessage.ID.password.rawValue, UInt8(ascii: "p")) - XCTAssertEqual(PSQLFrontendMessage.ID.saslInitialResponse.rawValue, UInt8(ascii: "p")) - XCTAssertEqual(PSQLFrontendMessage.ID.saslResponse.rawValue, UInt8(ascii: "p")) - XCTAssertEqual(PSQLFrontendMessage.ID.sync.rawValue, UInt8(ascii: "S")) - XCTAssertEqual(PSQLFrontendMessage.ID.terminate.rawValue, UInt8(ascii: "X")) + XCTAssertEqual(PostgresFrontendMessage.ID.bind.rawValue, UInt8(ascii: "B")) + XCTAssertEqual(PostgresFrontendMessage.ID.close.rawValue, UInt8(ascii: "C")) + XCTAssertEqual(PostgresFrontendMessage.ID.describe.rawValue, UInt8(ascii: "D")) + XCTAssertEqual(PostgresFrontendMessage.ID.execute.rawValue, UInt8(ascii: "E")) + XCTAssertEqual(PostgresFrontendMessage.ID.flush.rawValue, UInt8(ascii: "H")) + XCTAssertEqual(PostgresFrontendMessage.ID.parse.rawValue, UInt8(ascii: "P")) + XCTAssertEqual(PostgresFrontendMessage.ID.password.rawValue, UInt8(ascii: "p")) + XCTAssertEqual(PostgresFrontendMessage.ID.saslInitialResponse.rawValue, UInt8(ascii: "p")) + XCTAssertEqual(PostgresFrontendMessage.ID.saslResponse.rawValue, UInt8(ascii: "p")) + XCTAssertEqual(PostgresFrontendMessage.ID.sync.rawValue, UInt8(ascii: "S")) + XCTAssertEqual(PostgresFrontendMessage.ID.terminate.rawValue, UInt8(ascii: "X")) } // MARK: Encoder @@ -28,7 +28,7 @@ class PSQLFrontendMessageTests: XCTestCase { encoder.encode(data: .flush, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 5) - XCTAssertEqual(PSQLFrontendMessage.ID.flush.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(PostgresFrontendMessage.ID.flush.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(4, byteBuffer.readInteger(as: Int32.self)) // payload length } @@ -38,7 +38,7 @@ class PSQLFrontendMessageTests: XCTestCase { encoder.encode(data: .sync, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 5) - XCTAssertEqual(PSQLFrontendMessage.ID.sync.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(PostgresFrontendMessage.ID.sync.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(4, byteBuffer.readInteger(as: Int32.self)) // payload length } @@ -48,7 +48,7 @@ class PSQLFrontendMessageTests: XCTestCase { encoder.encode(data: .terminate, out: &byteBuffer) XCTAssertEqual(byteBuffer.readableBytes, 5) - XCTAssertEqual(PSQLFrontendMessage.ID.terminate.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(PostgresFrontendMessage.ID.terminate.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(4, byteBuffer.readInteger(as: Int32.self)) // payload length } From 43742ef1d66c4e1ba160d76aa8c11129f10b0972 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 2 Mar 2022 08:49:26 +0100 Subject: [PATCH 074/292] Rename PSQLBackendMessage to PostgresBackendMessage (#238) --- .../Connection/PostgresConnection.swift | 2 +- .../AuthenticationStateMachine.swift | 4 +- .../CloseStateMachine.swift | 2 +- .../ConnectionStateMachine.swift | 20 +-- .../ExtendedQueryStateMachine.swift | 6 +- .../PrepareStatementStateMachine.swift | 4 +- .../New/Extensions/ByteBuffer+PSQL.swift | 2 +- .../New/Messages/Authentication.swift | 6 +- .../New/Messages/BackendKeyData.swift | 4 +- .../PostgresNIO/New/Messages/DataRow.swift | 2 +- .../New/Messages/ErrorResponse.swift | 24 +-- .../New/Messages/NotificationResponse.swift | 4 +- .../New/Messages/ParameterDescription.swift | 2 +- .../New/Messages/ParameterStatus.swift | 4 +- .../New/Messages/ReadyForQuery.swift | 4 +- .../New/Messages/RowDescription.swift | 2 +- .../New/PSQLBackendMessageDecoder.swift | 10 +- .../PostgresNIO/New/PSQLChannelHandler.swift | 2 +- Sources/PostgresNIO/New/PSQLError.swift | 8 +- Sources/PostgresNIO/New/PSQLRowStream.swift | 2 +- ...age.swift => PostgresBackendMessage.swift} | 14 +- .../AuthenticationStateMachineTests.swift | 8 +- .../ConnectionStateMachineTests.swift | 2 +- .../New/Extensions/ByteBuffer+Utils.swift | 4 +- .../ConnectionAction+TestUtils.swift | 4 +- .../PSQLBackendMessage+Equatable.swift | 2 +- .../PSQLBackendMessageEncoder.swift | 24 +-- .../New/Messages/AuthenticationTests.swift | 2 +- .../New/Messages/BackendKeyDataTests.swift | 4 +- .../New/Messages/DataRowTests.swift | 2 +- .../New/Messages/ErrorResponseTests.swift | 4 +- .../Messages/NotificationResponseTests.swift | 2 +- .../Messages/ParameterDescriptionTests.swift | 2 +- .../New/Messages/ParameterStatusTests.swift | 2 +- .../New/Messages/ReadyForQueryTests.swift | 10 +- .../New/Messages/RowDescriptionTests.swift | 2 +- .../New/PSQLBackendMessageTests.swift | 164 +++++++++--------- .../New/PSQLChannelHandlerTests.swift | 14 +- 38 files changed, 190 insertions(+), 190 deletions(-) rename Sources/PostgresNIO/New/{PSQLBackendMessage.swift => PostgresBackendMessage.swift} (96%) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index d6a26e5e..be7e6c97 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -263,7 +263,7 @@ extension PostgresConnection { } extension PostgresConnection: PSQLChannelHandlerNotificationDelegate { - func notificationReceived(_ notification: PSQLBackendMessage.NotificationResponse) { + func notificationReceived(_ notification: PostgresBackendMessage.NotificationResponse) { self.underlying.eventLoop.assertInEventLoop() guard let listeners = self.notificationListeners[notification.channel] else { diff --git a/Sources/PostgresNIO/New/Connection State Machine/AuthenticationStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/AuthenticationStateMachine.swift index 5848288d..859a4d4b 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/AuthenticationStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/AuthenticationStateMachine.swift @@ -42,7 +42,7 @@ struct AuthenticationStateMachine { return .sendStartupMessage(self.authContext) } - mutating func authenticationMessageReceived(_ message: PSQLBackendMessage.Authentication) -> Action { + mutating func authenticationMessageReceived(_ message: PostgresBackendMessage.Authentication) -> Action { switch self.state { case .startupMessageSent: switch message { @@ -156,7 +156,7 @@ struct AuthenticationStateMachine { } } - mutating func errorReceived(_ message: PSQLBackendMessage.ErrorResponse) -> Action { + mutating func errorReceived(_ message: PostgresBackendMessage.ErrorResponse) -> Action { return self.setAndFireError(.server(message)) } diff --git a/Sources/PostgresNIO/New/Connection State Machine/CloseStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/CloseStateMachine.swift index 0dccd10d..791cebdd 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/CloseStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/CloseStateMachine.swift @@ -43,7 +43,7 @@ struct CloseStateMachine { return .succeedClose(closeContext) } - mutating func errorReceived(_ errorMessage: PSQLBackendMessage.ErrorResponse) -> Action { + mutating func errorReceived(_ errorMessage: PostgresBackendMessage.ErrorResponse) -> Action { let error = PSQLError.server(errorMessage) switch self.state { case .initialized: diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 82db845f..4a1a2813 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -2,7 +2,7 @@ import NIOCore struct ConnectionStateMachine { - typealias TransactionState = PSQLBackendMessage.TransactionState + typealias TransactionState = PostgresBackendMessage.TransactionState struct ConnectionContext { let processID: Int32 @@ -71,7 +71,7 @@ struct ConnectionStateMachine { case sendSSLRequest case establishSSLConnection case provideAuthenticationContext - case forwardNotificationToListeners(PSQLBackendMessage.NotificationResponse) + case forwardNotificationToListeners(PostgresBackendMessage.NotificationResponse) case fireEventReadyForQuery case fireChannelInactive /// Close the connection by sending a `Terminate` message and then closing the connection. This is for clean shutdowns. @@ -319,7 +319,7 @@ struct ConnectionStateMachine { } } - mutating func authenticationMessageReceived(_ message: PSQLBackendMessage.Authentication) -> ConnectionAction { + mutating func authenticationMessageReceived(_ message: PostgresBackendMessage.Authentication) -> ConnectionAction { guard case .authenticating(var authState) = self.state else { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.authentication(message))) } @@ -331,7 +331,7 @@ struct ConnectionStateMachine { } } - mutating func backendKeyDataReceived(_ keyData: PSQLBackendMessage.BackendKeyData) -> ConnectionAction { + mutating func backendKeyDataReceived(_ keyData: PostgresBackendMessage.BackendKeyData) -> ConnectionAction { guard case .authenticated(_, let parameters) = self.state else { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.backendKeyData(keyData))) } @@ -344,7 +344,7 @@ struct ConnectionStateMachine { return .wait } - mutating func parameterStatusReceived(_ status: PSQLBackendMessage.ParameterStatus) -> ConnectionAction { + mutating func parameterStatusReceived(_ status: PostgresBackendMessage.ParameterStatus) -> ConnectionAction { switch self.state { case .sslRequestSent, .sslNegotiated, @@ -394,7 +394,7 @@ struct ConnectionStateMachine { } } - mutating func errorReceived(_ errorMessage: PSQLBackendMessage.ErrorResponse) -> ConnectionAction { + mutating func errorReceived(_ errorMessage: PostgresBackendMessage.ErrorResponse) -> ConnectionAction { switch self.state { case .sslRequestSent, .sslNegotiated, @@ -508,7 +508,7 @@ struct ConnectionStateMachine { } } - mutating func noticeReceived(_ notice: PSQLBackendMessage.NoticeResponse) -> ConnectionAction { + mutating func noticeReceived(_ notice: PostgresBackendMessage.NoticeResponse) -> ConnectionAction { switch self.state { case .extendedQuery(var extendedQuery, let connectionContext): return self.avoidingStateMachineCoW { machine -> ConnectionAction in @@ -521,11 +521,11 @@ struct ConnectionStateMachine { } } - mutating func notificationReceived(_ notification: PSQLBackendMessage.NotificationResponse) -> ConnectionAction { + mutating func notificationReceived(_ notification: PostgresBackendMessage.NotificationResponse) -> ConnectionAction { return .forwardNotificationToListeners(notification) } - mutating func readyForQueryReceived(_ transactionState: PSQLBackendMessage.TransactionState) -> ConnectionAction { + mutating func readyForQueryReceived(_ transactionState: PostgresBackendMessage.TransactionState) -> ConnectionAction { switch self.state { case .authenticated(let backendKeyData, let parameters): guard let keyData = backendKeyData else { @@ -715,7 +715,7 @@ struct ConnectionStateMachine { } } - mutating func parameterDescriptionReceived(_ description: PSQLBackendMessage.ParameterDescription) -> ConnectionAction { + mutating func parameterDescriptionReceived(_ description: PostgresBackendMessage.ParameterDescription) -> ConnectionAction { switch self.state { case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: return self.avoidingStateMachineCoW { machine -> ConnectionAction in diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index c778477a..333742bb 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -83,7 +83,7 @@ struct ExtendedQueryStateMachine { } } - mutating func parameterDescriptionReceived(_ parameterDescription: PSQLBackendMessage.ParameterDescription) -> Action { + mutating func parameterDescriptionReceived(_ parameterDescription: PostgresBackendMessage.ParameterDescription) -> Action { guard case .parseCompleteReceived(let queryContext) = self.state else { return self.setAndFireError(.unexpectedBackendMessage(.parameterDescription(parameterDescription))) } @@ -217,7 +217,7 @@ struct ExtendedQueryStateMachine { preconditionFailure("Unimplemented") } - mutating func errorReceived(_ errorMessage: PSQLBackendMessage.ErrorResponse) -> Action { + mutating func errorReceived(_ errorMessage: PostgresBackendMessage.ErrorResponse) -> Action { let error = PSQLError.server(errorMessage) switch self.state { case .initialized: @@ -244,7 +244,7 @@ struct ExtendedQueryStateMachine { } } - mutating func noticeReceived(_ notice: PSQLBackendMessage.NoticeResponse) -> Action { + mutating func noticeReceived(_ notice: PostgresBackendMessage.NoticeResponse) -> Action { //self.queryObject.noticeReceived(notice) return .wait } diff --git a/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift index 947c8f97..5b65fc90 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift @@ -54,7 +54,7 @@ struct PrepareStatementStateMachine { return .wait } - mutating func parameterDescriptionReceived(_ parameterDescription: PSQLBackendMessage.ParameterDescription) -> Action { + mutating func parameterDescriptionReceived(_ parameterDescription: PostgresBackendMessage.ParameterDescription) -> Action { guard case .parseCompleteReceived(let createContext) = self.state else { return self.setAndFireError(.unexpectedBackendMessage(.parameterDescription(parameterDescription))) } @@ -81,7 +81,7 @@ struct PrepareStatementStateMachine { return .succeedPreparedStatementCreation(queryContext, with: rowDescription) } - mutating func errorReceived(_ errorMessage: PSQLBackendMessage.ErrorResponse) -> Action { + mutating func errorReceived(_ errorMessage: PostgresBackendMessage.ErrorResponse) -> Action { let error = PSQLError.server(errorMessage) switch self.state { case .initialized: diff --git a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift index 6793d50e..f226bd7b 100644 --- a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift +++ b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift @@ -2,7 +2,7 @@ import NIOCore internal extension ByteBuffer { - mutating func psqlWriteBackendMessageID(_ messageID: PSQLBackendMessage.ID) { + mutating func psqlWriteBackendMessageID(_ messageID: PostgresBackendMessage.ID) { self.writeInteger(messageID.rawValue) } diff --git a/Sources/PostgresNIO/New/Messages/Authentication.swift b/Sources/PostgresNIO/New/Messages/Authentication.swift index 54d7c6ad..bd0d2e57 100644 --- a/Sources/PostgresNIO/New/Messages/Authentication.swift +++ b/Sources/PostgresNIO/New/Messages/Authentication.swift @@ -1,6 +1,6 @@ import NIOCore -extension PSQLBackendMessage { +extension PostgresBackendMessage { enum Authentication: PayloadDecodable { case ok @@ -61,7 +61,7 @@ extension PSQLBackendMessage { } } -extension PSQLBackendMessage.Authentication: Equatable { +extension PostgresBackendMessage.Authentication: Equatable { static func ==(lhs: Self, rhs: Self) -> Bool { switch (lhs, rhs) { case (.ok, .ok): @@ -92,7 +92,7 @@ extension PSQLBackendMessage.Authentication: Equatable { } } -extension PSQLBackendMessage.Authentication: CustomDebugStringConvertible { +extension PostgresBackendMessage.Authentication: CustomDebugStringConvertible { var debugDescription: String { switch self { case .ok: diff --git a/Sources/PostgresNIO/New/Messages/BackendKeyData.swift b/Sources/PostgresNIO/New/Messages/BackendKeyData.swift index 2d6a23a4..498c5110 100644 --- a/Sources/PostgresNIO/New/Messages/BackendKeyData.swift +++ b/Sources/PostgresNIO/New/Messages/BackendKeyData.swift @@ -1,6 +1,6 @@ import NIOCore -extension PSQLBackendMessage { +extension PostgresBackendMessage { struct BackendKeyData: PayloadDecodable, Equatable { let processID: Int32 @@ -16,7 +16,7 @@ extension PSQLBackendMessage { } } -extension PSQLBackendMessage.BackendKeyData: CustomDebugStringConvertible { +extension PostgresBackendMessage.BackendKeyData: CustomDebugStringConvertible { var debugDescription: String { "processID: \(processID), secretKey: \(secretKey)" } diff --git a/Sources/PostgresNIO/New/Messages/DataRow.swift b/Sources/PostgresNIO/New/Messages/DataRow.swift index 31148c20..b49c9eeb 100644 --- a/Sources/PostgresNIO/New/Messages/DataRow.swift +++ b/Sources/PostgresNIO/New/Messages/DataRow.swift @@ -8,7 +8,7 @@ import NIOCore /// enclosing type, the enclosing type must be @usableFromInline as well. /// Not putting `DataRow` in ``PSQLBackendMessage`` is our way to trick /// the Swift compiler -struct DataRow: PSQLBackendMessage.PayloadDecodable, Equatable { +struct DataRow: PostgresBackendMessage.PayloadDecodable, Equatable { var columnCount: Int16 diff --git a/Sources/PostgresNIO/New/Messages/ErrorResponse.swift b/Sources/PostgresNIO/New/Messages/ErrorResponse.swift index 254cdf0f..818c1ebf 100644 --- a/Sources/PostgresNIO/New/Messages/ErrorResponse.swift +++ b/Sources/PostgresNIO/New/Messages/ErrorResponse.swift @@ -1,6 +1,6 @@ import NIOCore -extension PSQLBackendMessage { +extension PostgresBackendMessage { enum Field: UInt8, Hashable { /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), @@ -81,40 +81,40 @@ extension PSQLBackendMessage { } struct ErrorResponse: PSQLMessageNotice, PayloadDecodable, Equatable { - let fields: [PSQLBackendMessage.Field: String] + let fields: [PostgresBackendMessage.Field: String] - init(fields: [PSQLBackendMessage.Field: String]) { + init(fields: [PostgresBackendMessage.Field: String]) { self.fields = fields } } struct NoticeResponse: PSQLMessageNotice, PayloadDecodable, Equatable { - let fields: [PSQLBackendMessage.Field: String] + let fields: [PostgresBackendMessage.Field: String] - init(fields: [PSQLBackendMessage.Field: String]) { + init(fields: [PostgresBackendMessage.Field: String]) { self.fields = fields } } } protocol PSQLMessageNotice { - var fields: [PSQLBackendMessage.Field: String] { get } + var fields: [PostgresBackendMessage.Field: String] { get } - init(fields: [PSQLBackendMessage.Field: String]) + init(fields: [PostgresBackendMessage.Field: String]) } -extension PSQLBackendMessage.PayloadDecodable where Self: PSQLMessageNotice { +extension PostgresBackendMessage.PayloadDecodable where Self: PSQLMessageNotice { static func decode(from buffer: inout ByteBuffer) throws -> Self { - var fields: [PSQLBackendMessage.Field: String] = [:] + var fields: [PostgresBackendMessage.Field: String] = [:] while let id = buffer.readInteger(as: UInt8.self) { if id == 0 { break } - guard let field = PSQLBackendMessage.Field(rawValue: id) else { + guard let field = PostgresBackendMessage.Field(rawValue: id) else { throw PSQLPartialDecodingError.valueNotRawRepresentable( value: id, - asType: PSQLBackendMessage.Field.self) + asType: PostgresBackendMessage.Field.self) } guard let string = buffer.readNullTerminatedString() else { @@ -126,7 +126,7 @@ extension PSQLBackendMessage.PayloadDecodable where Self: PSQLMessageNotice { } } -extension PSQLBackendMessage.Field: CustomStringConvertible { +extension PostgresBackendMessage.Field: CustomStringConvertible { var description: String { switch self { diff --git a/Sources/PostgresNIO/New/Messages/NotificationResponse.swift b/Sources/PostgresNIO/New/Messages/NotificationResponse.swift index dd5c0cf2..5cd9422e 100644 --- a/Sources/PostgresNIO/New/Messages/NotificationResponse.swift +++ b/Sources/PostgresNIO/New/Messages/NotificationResponse.swift @@ -1,13 +1,13 @@ import NIOCore -extension PSQLBackendMessage { +extension PostgresBackendMessage { struct NotificationResponse: PayloadDecodable, Equatable { let backendPID: Int32 let channel: String let payload: String - static func decode(from buffer: inout ByteBuffer) throws -> PSQLBackendMessage.NotificationResponse { + static func decode(from buffer: inout ByteBuffer) throws -> PostgresBackendMessage.NotificationResponse { let backendPID = try buffer.throwingReadInteger(as: Int32.self) guard let channel = buffer.readNullTerminatedString() else { diff --git a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift index bd468c44..0d519583 100644 --- a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift +++ b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift @@ -1,6 +1,6 @@ import NIOCore -extension PSQLBackendMessage { +extension PostgresBackendMessage { struct ParameterDescription: PayloadDecodable, Equatable { /// Specifies the object ID of the parameter data type. diff --git a/Sources/PostgresNIO/New/Messages/ParameterStatus.swift b/Sources/PostgresNIO/New/Messages/ParameterStatus.swift index 89dd1d6d..4ffcbe12 100644 --- a/Sources/PostgresNIO/New/Messages/ParameterStatus.swift +++ b/Sources/PostgresNIO/New/Messages/ParameterStatus.swift @@ -1,6 +1,6 @@ import NIOCore -extension PSQLBackendMessage { +extension PostgresBackendMessage { struct ParameterStatus: PayloadDecodable, Equatable { /// The name of the run-time parameter being reported. @@ -23,7 +23,7 @@ extension PSQLBackendMessage { } } -extension PSQLBackendMessage.ParameterStatus: CustomDebugStringConvertible { +extension PostgresBackendMessage.ParameterStatus: CustomDebugStringConvertible { var debugDescription: String { "parameter: \(String(reflecting: self.parameter)), value: \(String(reflecting: self.value))" } diff --git a/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift b/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift index b8fff2aa..a300f714 100644 --- a/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift +++ b/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift @@ -1,6 +1,6 @@ import NIOCore -extension PSQLBackendMessage { +extension PostgresBackendMessage { enum TransactionState: PayloadDecodable, RawRepresentable { typealias RawValue = UInt8 @@ -43,7 +43,7 @@ extension PSQLBackendMessage { } } -extension PSQLBackendMessage.TransactionState: CustomDebugStringConvertible { +extension PostgresBackendMessage.TransactionState: CustomDebugStringConvertible { var debugDescription: String { switch self { case .idle: diff --git a/Sources/PostgresNIO/New/Messages/RowDescription.swift b/Sources/PostgresNIO/New/Messages/RowDescription.swift index 9ca491db..de855e98 100644 --- a/Sources/PostgresNIO/New/Messages/RowDescription.swift +++ b/Sources/PostgresNIO/New/Messages/RowDescription.swift @@ -8,7 +8,7 @@ import NIOCore /// enclosing type, the enclosing type must be @usableFromInline as well. /// Not putting `DataRow` in ``PSQLBackendMessage`` is our way to trick /// the Swift compiler. -struct RowDescription: PSQLBackendMessage.PayloadDecodable, Equatable { +struct RowDescription: PostgresBackendMessage.PayloadDecodable, Equatable { /// Specifies the object ID of the parameter data type. var columns: [Column] diff --git a/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift b/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift index 47485a7b..9a3d6628 100644 --- a/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift +++ b/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift @@ -1,5 +1,5 @@ struct PSQLBackendMessageDecoder: NIOSingleStepByteToMessageDecoder { - typealias InboundOut = PSQLBackendMessage + typealias InboundOut = PostgresBackendMessage private(set) var hasAlreadyReceivedBytes: Bool @@ -7,7 +7,7 @@ struct PSQLBackendMessageDecoder: NIOSingleStepByteToMessageDecoder { self.hasAlreadyReceivedBytes = hasAlreadyReceivedBytes } - mutating func decode(buffer: inout ByteBuffer) throws -> PSQLBackendMessage? { + mutating func decode(buffer: inout ByteBuffer) throws -> PostgresBackendMessage? { if !self.hasAlreadyReceivedBytes { // We have not received any bytes yet! Let's peek at the first message id. If it @@ -51,7 +51,7 @@ struct PSQLBackendMessageDecoder: NIOSingleStepByteToMessageDecoder { } // 2. make sure we have a known message identifier - guard let messageID = PSQLBackendMessage.ID(rawValue: idByte) else { + guard let messageID = PostgresBackendMessage.ID(rawValue: idByte) else { buffer.moveReaderIndex(to: startReaderIndex) let completeMessage = buffer.readSlice(length: Int(length) + 1)! throw PSQLDecodingError.unknownMessageIDReceived(messageID: idByte, messageBytes: completeMessage) @@ -59,7 +59,7 @@ struct PSQLBackendMessageDecoder: NIOSingleStepByteToMessageDecoder { // 3. decode the message do { - let result = try PSQLBackendMessage.decode(from: &message, for: messageID) + let result = try PostgresBackendMessage.decode(from: &message, for: messageID) if message.readableBytes > 0 { throw PSQLPartialDecodingError.expectedExactlyNRemainingBytes(0, actual: message.readableBytes) } @@ -73,7 +73,7 @@ struct PSQLBackendMessageDecoder: NIOSingleStepByteToMessageDecoder { } } - mutating func decodeLast(buffer: inout ByteBuffer, seenEOF: Bool) throws -> PSQLBackendMessage? { + mutating func decodeLast(buffer: inout ByteBuffer, seenEOF: Bool) throws -> PostgresBackendMessage? { try self.decode(buffer: &buffer) } } diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PSQLChannelHandler.swift index f0671bca..c39537d6 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PSQLChannelHandler.swift @@ -4,7 +4,7 @@ import Crypto import Logging protocol PSQLChannelHandlerNotificationDelegate: AnyObject { - func notificationReceived(_: PSQLBackendMessage.NotificationResponse) + func notificationReceived(_: PostgresBackendMessage.NotificationResponse) } final class PSQLChannelHandler: ChannelDuplexHandler { diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index cdcf86c2..022d6016 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -5,9 +5,9 @@ struct PSQLError: Error { enum Base { case sslUnsupported case failedToAddSSLHandler(underlying: Error) - case server(PSQLBackendMessage.ErrorResponse) + case server(PostgresBackendMessage.ErrorResponse) case decoding(PSQLDecodingError) - case unexpectedBackendMessage(PSQLBackendMessage) + case unexpectedBackendMessage(PostgresBackendMessage) case unsupportedAuthMechanism(PSQLAuthScheme) case authMechanismRequiresPassword case saslError(underlyingError: Error) @@ -35,7 +35,7 @@ struct PSQLError: Error { Self.init(.failedToAddSSLHandler(underlying: error)) } - static func server(_ message: PSQLBackendMessage.ErrorResponse) -> PSQLError { + static func server(_ message: PostgresBackendMessage.ErrorResponse) -> PSQLError { Self.init(.server(message)) } @@ -43,7 +43,7 @@ struct PSQLError: Error { Self.init(.decoding(error)) } - static func unexpectedBackendMessage(_ message: PSQLBackendMessage) -> PSQLError { + static func unexpectedBackendMessage(_ message: PostgresBackendMessage) -> PSQLError { Self.init(.unexpectedBackendMessage(message)) } diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index 787c6cef..2d0ec455 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -255,7 +255,7 @@ final class PSQLRowStream { } } - internal func noticeReceived(_ notice: PSQLBackendMessage.NoticeResponse) { + internal func noticeReceived(_ notice: PostgresBackendMessage.NoticeResponse) { self.logger.debug("Notice Received", metadata: [ .notice: "\(notice)" ]) diff --git a/Sources/PostgresNIO/New/PSQLBackendMessage.swift b/Sources/PostgresNIO/New/PostgresBackendMessage.swift similarity index 96% rename from Sources/PostgresNIO/New/PSQLBackendMessage.swift rename to Sources/PostgresNIO/New/PostgresBackendMessage.swift index 77f7b78b..ecccd1e9 100644 --- a/Sources/PostgresNIO/New/PSQLBackendMessage.swift +++ b/Sources/PostgresNIO/New/PostgresBackendMessage.swift @@ -2,10 +2,10 @@ import NIOCore //import struct Foundation.Data -/// A protocol to implement for all associated value in the `PSQLBackendMessage` enum +/// A protocol to implement for all associated value in the `PostgresBackendMessage` enum protocol PSQLMessagePayloadDecodable { - /// Decodes the associated value for a `PSQLBackendMessage` from the given `ByteBuffer`. + /// Decodes the associated value for a `PostgresBackendMessage` from the given `ByteBuffer`. /// /// When the decoding is done all bytes in the given `ByteBuffer` must be consumed. /// `buffer.readableBytes` must be `0`. In case of an error a `PartialDecodingError` @@ -20,7 +20,7 @@ protocol PSQLMessagePayloadDecodable { /// /// All messages are defined in the official Postgres Documentation in the section /// [Frontend/Backend Protocol – Message Formats](https://www.postgresql.org/docs/13/protocol-message-formats.html) -enum PSQLBackendMessage { +enum PostgresBackendMessage { typealias PayloadDecodable = PSQLMessagePayloadDecodable @@ -45,7 +45,7 @@ enum PSQLBackendMessage { case sslUnsupported } -extension PSQLBackendMessage { +extension PostgresBackendMessage { enum ID: RawRepresentable, Equatable { typealias RawValue = UInt8 @@ -184,9 +184,9 @@ extension PSQLBackendMessage { } } -extension PSQLBackendMessage { +extension PostgresBackendMessage { - static func decode(from buffer: inout ByteBuffer, for messageID: ID) throws -> PSQLBackendMessage { + static func decode(from buffer: inout ByteBuffer, for messageID: ID) throws -> PostgresBackendMessage { switch messageID { case .authentication: return try .authentication(.decode(from: &buffer)) @@ -248,7 +248,7 @@ extension PSQLBackendMessage { } } -extension PSQLBackendMessage: CustomDebugStringConvertible { +extension PostgresBackendMessage: CustomDebugStringConvertible { var debugDescription: String { switch self { case .authentication(let authentication): diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift index 2ed28c20..238f4884 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift @@ -53,7 +53,7 @@ class AuthenticationStateMachineTests: XCTestCase { XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) - let fields: [PSQLBackendMessage.Field: String] = [ + let fields: [PostgresBackendMessage.Field: String] = [ .message: "password authentication failed for user \"postgres\"", .severity: "FATAL", .sqlState: "28P01", @@ -69,7 +69,7 @@ class AuthenticationStateMachineTests: XCTestCase { // MARK: Test unsupported messages func testUnsupportedAuthMechanism() { - let unsupported: [(PSQLBackendMessage.Authentication, PSQLAuthScheme)] = [ + let unsupported: [(PostgresBackendMessage.Authentication, PSQLAuthScheme)] = [ (.kerberosV5, .kerberosV5), (.scmCredential, .scmCredential), (.gss, .gss), @@ -90,7 +90,7 @@ class AuthenticationStateMachineTests: XCTestCase { func testUnexpectedMessagesAfterStartUp() { var buffer = ByteBuffer() buffer.writeBytes([0, 1, 2, 3, 4, 5, 6, 7, 8]) - let unexpected: [PSQLBackendMessage.Authentication] = [ + let unexpected: [PostgresBackendMessage.Authentication] = [ .gssContinue(data: buffer), .saslContinue(data: buffer), .saslFinal(data: buffer) @@ -110,7 +110,7 @@ class AuthenticationStateMachineTests: XCTestCase { let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) var buffer = ByteBuffer() buffer.writeBytes([0, 1, 2, 3, 4, 5, 6, 7, 8]) - let unexpected: [PSQLBackendMessage.Authentication] = [ + let unexpected: [PostgresBackendMessage.Authentication] = [ .kerberosV5, .md5(salt: (0, 1, 2, 3)), .plaintext, diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index 5b7ed388..4a63e31c 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -143,7 +143,7 @@ class ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) - let fields: [PSQLBackendMessage.Field: String] = [ + let fields: [PostgresBackendMessage.Field: String] = [ .message: "password authentication failed for user \"postgres\"", .severity: "FATAL", .sqlState: "28P01", diff --git a/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift b/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift index 835965da..71994596 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift @@ -3,13 +3,13 @@ import NIOCore extension ByteBuffer { - static func backendMessage(id: PSQLBackendMessage.ID, _ payload: (inout ByteBuffer) throws -> ()) rethrows -> ByteBuffer { + static func backendMessage(id: PostgresBackendMessage.ID, _ payload: (inout ByteBuffer) throws -> ()) rethrows -> ByteBuffer { var byteBuffer = ByteBuffer() try byteBuffer.writeBackendMessage(id: id, payload) return byteBuffer } - mutating func writeBackendMessage(id: PSQLBackendMessage.ID, _ payload: (inout ByteBuffer) throws -> ()) rethrows { + mutating func writeBackendMessage(id: PostgresBackendMessage.ID, _ payload: (inout ByteBuffer) throws -> ()) rethrows { self.psqlWriteBackendMessageID(id) let lengthIndex = self.writerIndex self.writeInteger(Int32(0)) diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift index 13323e76..448183b5 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -67,12 +67,12 @@ extension ConnectionStateMachine.ConnectionAction.CleanUpContext: Equatable { } extension ConnectionStateMachine { - static func readyForQuery(transactionState: PSQLBackendMessage.TransactionState = .idle) -> Self { + static func readyForQuery(transactionState: PostgresBackendMessage.TransactionState = .idle) -> Self { let connectionContext = Self.createConnectionContext(transactionState: transactionState) return ConnectionStateMachine(.readyForQuery(connectionContext)) } - static func createConnectionContext(transactionState: PSQLBackendMessage.TransactionState = .idle) -> ConnectionContext { + static func createConnectionContext(transactionState: PostgresBackendMessage.TransactionState = .idle) -> ConnectionContext { let paramaters = [ "DateStyle": "ISO, MDY", "application_name": "", diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift index 436c7aa9..c459ffeb 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift @@ -1,6 +1,6 @@ @testable import PostgresNIO -extension PSQLBackendMessage: Equatable { +extension PostgresBackendMessage: Equatable { public static func ==(lhs: Self, rhs: Self) -> Bool { switch (lhs, rhs) { diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift index 8ef8033c..eea7dec3 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift @@ -2,14 +2,14 @@ import NIOCore @testable import PostgresNIO struct PSQLBackendMessageEncoder: MessageToByteEncoder { - typealias OutboundIn = PSQLBackendMessage + typealias OutboundIn = PostgresBackendMessage /// Called once there is data to encode. /// /// - parameters: /// - data: The data to encode into a `ByteBuffer`. /// - out: The `ByteBuffer` into which we want to encode. - func encode(data message: PSQLBackendMessage, out buffer: inout ByteBuffer) throws { + func encode(data message: PostgresBackendMessage, out buffer: inout ByteBuffer) throws { switch message { case .authentication(let authentication): self.encode(messageID: message.id, payload: authentication, into: &buffer) @@ -73,7 +73,7 @@ struct PSQLBackendMessageEncoder: MessageToByteEncoder { } private func encode( - messageID: PSQLBackendMessage.ID, + messageID: PostgresBackendMessage.ID, payload: Payload, into buffer: inout ByteBuffer) { @@ -86,7 +86,7 @@ struct PSQLBackendMessageEncoder: MessageToByteEncoder { } } -extension PSQLBackendMessage { +extension PostgresBackendMessage { var id: ID { switch self { case .authentication: @@ -130,7 +130,7 @@ extension PSQLBackendMessage { } } -extension PSQLBackendMessage.Authentication: PSQLMessagePayloadEncodable { +extension PostgresBackendMessage.Authentication: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { switch self { @@ -181,7 +181,7 @@ extension PSQLBackendMessage.Authentication: PSQLMessagePayloadEncodable { } -extension PSQLBackendMessage.BackendKeyData: PSQLMessagePayloadEncodable { +extension PostgresBackendMessage.BackendKeyData: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { buffer.writeInteger(self.processID) buffer.writeInteger(self.secretKey) @@ -195,7 +195,7 @@ extension DataRow: PSQLMessagePayloadEncodable { } } -extension PSQLBackendMessage.ErrorResponse: PSQLMessagePayloadEncodable { +extension PostgresBackendMessage.ErrorResponse: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { for (key, value) in self.fields { buffer.writeInteger(key.rawValue, as: UInt8.self) @@ -205,7 +205,7 @@ extension PSQLBackendMessage.ErrorResponse: PSQLMessagePayloadEncodable { } } -extension PSQLBackendMessage.NoticeResponse: PSQLMessagePayloadEncodable { +extension PostgresBackendMessage.NoticeResponse: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { for (key, value) in self.fields { buffer.writeInteger(key.rawValue, as: UInt8.self) @@ -215,7 +215,7 @@ extension PSQLBackendMessage.NoticeResponse: PSQLMessagePayloadEncodable { } } -extension PSQLBackendMessage.NotificationResponse: PSQLMessagePayloadEncodable { +extension PostgresBackendMessage.NotificationResponse: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { buffer.writeInteger(self.backendPID) buffer.writeNullTerminatedString(self.channel) @@ -223,7 +223,7 @@ extension PSQLBackendMessage.NotificationResponse: PSQLMessagePayloadEncodable { } } -extension PSQLBackendMessage.ParameterDescription: PSQLMessagePayloadEncodable { +extension PostgresBackendMessage.ParameterDescription: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { buffer.writeInteger(Int16(self.dataTypes.count)) @@ -233,14 +233,14 @@ extension PSQLBackendMessage.ParameterDescription: PSQLMessagePayloadEncodable { } } -extension PSQLBackendMessage.ParameterStatus: PSQLMessagePayloadEncodable { +extension PostgresBackendMessage.ParameterStatus: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { buffer.writeNullTerminatedString(self.parameter) buffer.writeNullTerminatedString(self.value) } } -extension PSQLBackendMessage.TransactionState: PSQLMessagePayloadEncodable { +extension PostgresBackendMessage.TransactionState: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { buffer.writeInteger(self.rawValue) } diff --git a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift index 52e63b2e..85a4314f 100644 --- a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift @@ -6,7 +6,7 @@ import NIOTestUtils class AuthenticationTests: XCTestCase { func testDecodeAuthentication() { - var expected = [PSQLBackendMessage]() + var expected = [PostgresBackendMessage]() var buffer = ByteBuffer() let encoder = PSQLBackendMessageEncoder() diff --git a/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift b/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift index 5715c61c..2db8493b 100644 --- a/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift @@ -11,7 +11,7 @@ class BackendKeyDataTests: XCTestCase { } let expectedInOuts = [ - (buffer, [PSQLBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 4567))]), + (buffer, [PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 4567))]), ] XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( @@ -27,7 +27,7 @@ class BackendKeyDataTests: XCTestCase { buffer.writeInteger(Int32(4567)) let expected = [ - (buffer, [PSQLBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 4567))]), + (buffer, [PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 4567))]), ] XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( diff --git a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift index 643c8a28..660baa92 100644 --- a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift @@ -23,7 +23,7 @@ class DataRowTests: XCTestCase { let rowSlice = buffer.getSlice(at: 7, length: buffer.readableBytes - 7)! let expectedInOuts = [ - (buffer, [PSQLBackendMessage.dataRow(.init(columnCount: 3, bytes: rowSlice))]), + (buffer, [PostgresBackendMessage.dataRow(.init(columnCount: 3, bytes: rowSlice))]), ] XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( diff --git a/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift index bbc945e4..038ec34c 100644 --- a/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift @@ -6,7 +6,7 @@ import NIOTestUtils class ErrorResponseTests: XCTestCase { func testDecode() { - let fields: [PSQLBackendMessage.Field : String] = [ + let fields: [PostgresBackendMessage.Field : String] = [ .file: "auth.c", .routine: "auth_failed", .line: "334", @@ -25,7 +25,7 @@ class ErrorResponseTests: XCTestCase { } let expectedInOuts = [ - (buffer, [PSQLBackendMessage.error(.init(fields: fields))]), + (buffer, [PostgresBackendMessage.error(.init(fields: fields))]), ] XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( diff --git a/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift index 39fbb220..f41a74af 100644 --- a/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift @@ -6,7 +6,7 @@ import NIOTestUtils class NotificationResponseTests: XCTestCase { func testDecode() { - let expected: [PSQLBackendMessage] = [ + let expected: [PostgresBackendMessage] = [ .notification(.init(backendPID: 123, channel: "test", payload: "hello")), .notification(.init(backendPID: 123, channel: "test", payload: "world")), .notification(.init(backendPID: 123, channel: "foo", payload: "bar")) diff --git a/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift b/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift index 8bbdae4c..5c3ff150 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift @@ -6,7 +6,7 @@ import NIOTestUtils class ParameterDescriptionTests: XCTestCase { func testDecode() { - let expected: [PSQLBackendMessage] = [ + let expected: [PostgresBackendMessage] = [ .parameterDescription(.init(dataTypes: [.bool, .varchar, .uuid, .json, .jsonbArray])), ] diff --git a/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift b/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift index db4963e0..a84e2ac4 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift @@ -8,7 +8,7 @@ class ParameterStatusTests: XCTestCase { func testDecode() { var buffer = ByteBuffer() - let expected: [PSQLBackendMessage] = [ + let expected: [PostgresBackendMessage] = [ .parameterStatus(.init(parameter: "DateStyle", value: "ISO, MDY")), .parameterStatus(.init(parameter: "application_name", value: "")), .parameterStatus(.init(parameter: "server_encoding", value: "UTF8")), diff --git a/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift b/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift index 55a2c1e7..8ece1bfc 100644 --- a/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift @@ -8,7 +8,7 @@ class ReadyForQueryTests: XCTestCase { func testDecode() { var buffer = ByteBuffer() - let states: [PSQLBackendMessage.TransactionState] = [ + let states: [PostgresBackendMessage.TransactionState] = [ .idle, .inFailedTransaction, .inTransaction, @@ -27,7 +27,7 @@ class ReadyForQueryTests: XCTestCase { } } - let expected = states.map { state -> PSQLBackendMessage in + let expected = states.map { state -> PostgresBackendMessage in .readyForQuery(state) } @@ -67,8 +67,8 @@ class ReadyForQueryTests: XCTestCase { } func testDebugDescription() { - XCTAssertEqual(String(reflecting: PSQLBackendMessage.TransactionState.idle), ".idle") - XCTAssertEqual(String(reflecting: PSQLBackendMessage.TransactionState.inTransaction), ".inTransaction") - XCTAssertEqual(String(reflecting: PSQLBackendMessage.TransactionState.inFailedTransaction), ".inFailedTransaction") + XCTAssertEqual(String(reflecting: PostgresBackendMessage.TransactionState.idle), ".idle") + XCTAssertEqual(String(reflecting: PostgresBackendMessage.TransactionState.inTransaction), ".inTransaction") + XCTAssertEqual(String(reflecting: PostgresBackendMessage.TransactionState.inFailedTransaction), ".inFailedTransaction") } } diff --git a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift index 8eba059d..7e941d54 100644 --- a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift @@ -11,7 +11,7 @@ class RowDescriptionTests: XCTestCase { .init(name: "Second", tableOID: 123, columnAttributeNumber: 456, dataType: .uuidArray, dataTypeSize: 567, dataTypeModifier: 123, format: .text), ] - let expected: [PSQLBackendMessage] = [ + let expected: [PostgresBackendMessage] = [ .rowDescription(.init(columns: columns)) ] diff --git a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift index 049e23d1..60209d2b 100644 --- a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift @@ -9,59 +9,59 @@ class PSQLBackendMessageTests: XCTestCase { // MARK: ID func testInitMessageIDWithBytes() { - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "R")), .authentication) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "K")), .backendKeyData) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "2")), .bindComplete) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "3")), .closeComplete) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "C")), .commandComplete) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "d")), .copyData) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "c")), .copyDone) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "G")), .copyInResponse) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "H")), .copyOutResponse) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "W")), .copyBothResponse) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "D")), .dataRow) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "I")), .emptyQueryResponse) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "E")), .error) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "V")), .functionCallResponse) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "v")), .negotiateProtocolVersion) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "n")), .noData) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "N")), .noticeResponse) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "A")), .notificationResponse) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "t")), .parameterDescription) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "S")), .parameterStatus) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "1")), .parseComplete) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "s")), .portalSuspended) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "Z")), .readyForQuery) - XCTAssertEqual(PSQLBackendMessage.ID(rawValue: UInt8(ascii: "T")), .rowDescription) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "R")), .authentication) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "K")), .backendKeyData) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "2")), .bindComplete) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "3")), .closeComplete) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "C")), .commandComplete) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "d")), .copyData) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "c")), .copyDone) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "G")), .copyInResponse) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "H")), .copyOutResponse) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "W")), .copyBothResponse) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "D")), .dataRow) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "I")), .emptyQueryResponse) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "E")), .error) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "V")), .functionCallResponse) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "v")), .negotiateProtocolVersion) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "n")), .noData) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "N")), .noticeResponse) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "A")), .notificationResponse) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "t")), .parameterDescription) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "S")), .parameterStatus) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "1")), .parseComplete) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "s")), .portalSuspended) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "Z")), .readyForQuery) + XCTAssertEqual(PostgresBackendMessage.ID(rawValue: UInt8(ascii: "T")), .rowDescription) - XCTAssertNil(PSQLBackendMessage.ID(rawValue: 0)) + XCTAssertNil(PostgresBackendMessage.ID(rawValue: 0)) } func testMessageIDHasCorrectRawValue() { - XCTAssertEqual(PSQLBackendMessage.ID.authentication.rawValue, UInt8(ascii: "R")) - XCTAssertEqual(PSQLBackendMessage.ID.backendKeyData.rawValue, UInt8(ascii: "K")) - XCTAssertEqual(PSQLBackendMessage.ID.bindComplete.rawValue, UInt8(ascii: "2")) - XCTAssertEqual(PSQLBackendMessage.ID.closeComplete.rawValue, UInt8(ascii: "3")) - XCTAssertEqual(PSQLBackendMessage.ID.commandComplete.rawValue, UInt8(ascii: "C")) - XCTAssertEqual(PSQLBackendMessage.ID.copyData.rawValue, UInt8(ascii: "d")) - XCTAssertEqual(PSQLBackendMessage.ID.copyDone.rawValue, UInt8(ascii: "c")) - XCTAssertEqual(PSQLBackendMessage.ID.copyInResponse.rawValue, UInt8(ascii: "G")) - XCTAssertEqual(PSQLBackendMessage.ID.copyOutResponse.rawValue, UInt8(ascii: "H")) - XCTAssertEqual(PSQLBackendMessage.ID.copyBothResponse.rawValue, UInt8(ascii: "W")) - XCTAssertEqual(PSQLBackendMessage.ID.dataRow.rawValue, UInt8(ascii: "D")) - XCTAssertEqual(PSQLBackendMessage.ID.emptyQueryResponse.rawValue, UInt8(ascii: "I")) - XCTAssertEqual(PSQLBackendMessage.ID.error.rawValue, UInt8(ascii: "E")) - XCTAssertEqual(PSQLBackendMessage.ID.functionCallResponse.rawValue, UInt8(ascii: "V")) - XCTAssertEqual(PSQLBackendMessage.ID.negotiateProtocolVersion.rawValue, UInt8(ascii: "v")) - XCTAssertEqual(PSQLBackendMessage.ID.noData.rawValue, UInt8(ascii: "n")) - XCTAssertEqual(PSQLBackendMessage.ID.noticeResponse.rawValue, UInt8(ascii: "N")) - XCTAssertEqual(PSQLBackendMessage.ID.notificationResponse.rawValue, UInt8(ascii: "A")) - XCTAssertEqual(PSQLBackendMessage.ID.parameterDescription.rawValue, UInt8(ascii: "t")) - XCTAssertEqual(PSQLBackendMessage.ID.parameterStatus.rawValue, UInt8(ascii: "S")) - XCTAssertEqual(PSQLBackendMessage.ID.parseComplete.rawValue, UInt8(ascii: "1")) - XCTAssertEqual(PSQLBackendMessage.ID.portalSuspended.rawValue, UInt8(ascii: "s")) - XCTAssertEqual(PSQLBackendMessage.ID.readyForQuery.rawValue, UInt8(ascii: "Z")) - XCTAssertEqual(PSQLBackendMessage.ID.rowDescription.rawValue, UInt8(ascii: "T")) + XCTAssertEqual(PostgresBackendMessage.ID.authentication.rawValue, UInt8(ascii: "R")) + XCTAssertEqual(PostgresBackendMessage.ID.backendKeyData.rawValue, UInt8(ascii: "K")) + XCTAssertEqual(PostgresBackendMessage.ID.bindComplete.rawValue, UInt8(ascii: "2")) + XCTAssertEqual(PostgresBackendMessage.ID.closeComplete.rawValue, UInt8(ascii: "3")) + XCTAssertEqual(PostgresBackendMessage.ID.commandComplete.rawValue, UInt8(ascii: "C")) + XCTAssertEqual(PostgresBackendMessage.ID.copyData.rawValue, UInt8(ascii: "d")) + XCTAssertEqual(PostgresBackendMessage.ID.copyDone.rawValue, UInt8(ascii: "c")) + XCTAssertEqual(PostgresBackendMessage.ID.copyInResponse.rawValue, UInt8(ascii: "G")) + XCTAssertEqual(PostgresBackendMessage.ID.copyOutResponse.rawValue, UInt8(ascii: "H")) + XCTAssertEqual(PostgresBackendMessage.ID.copyBothResponse.rawValue, UInt8(ascii: "W")) + XCTAssertEqual(PostgresBackendMessage.ID.dataRow.rawValue, UInt8(ascii: "D")) + XCTAssertEqual(PostgresBackendMessage.ID.emptyQueryResponse.rawValue, UInt8(ascii: "I")) + XCTAssertEqual(PostgresBackendMessage.ID.error.rawValue, UInt8(ascii: "E")) + XCTAssertEqual(PostgresBackendMessage.ID.functionCallResponse.rawValue, UInt8(ascii: "V")) + XCTAssertEqual(PostgresBackendMessage.ID.negotiateProtocolVersion.rawValue, UInt8(ascii: "v")) + XCTAssertEqual(PostgresBackendMessage.ID.noData.rawValue, UInt8(ascii: "n")) + XCTAssertEqual(PostgresBackendMessage.ID.noticeResponse.rawValue, UInt8(ascii: "N")) + XCTAssertEqual(PostgresBackendMessage.ID.notificationResponse.rawValue, UInt8(ascii: "A")) + XCTAssertEqual(PostgresBackendMessage.ID.parameterDescription.rawValue, UInt8(ascii: "t")) + XCTAssertEqual(PostgresBackendMessage.ID.parameterStatus.rawValue, UInt8(ascii: "S")) + XCTAssertEqual(PostgresBackendMessage.ID.parseComplete.rawValue, UInt8(ascii: "1")) + XCTAssertEqual(PostgresBackendMessage.ID.portalSuspended.rawValue, UInt8(ascii: "s")) + XCTAssertEqual(PostgresBackendMessage.ID.readyForQuery.rawValue, UInt8(ascii: "Z")) + XCTAssertEqual(PostgresBackendMessage.ID.rowDescription.rawValue, UInt8(ascii: "T")) } // MARK: Decoder @@ -70,11 +70,11 @@ class PSQLBackendMessageTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(UInt8(ascii: "S")) - var expectedMessages: [PSQLBackendMessage] = [.sslSupported] + var expectedMessages: [PostgresBackendMessage] = [.sslSupported] // we test tons of ParameterStatus messages after the SSLSupported message, since those are // also identified by an "S" - let parameterStatus: [PSQLBackendMessage.ParameterStatus] = [ + let parameterStatus: [PostgresBackendMessage.ParameterStatus] = [ .init(parameter: "DateStyle", value: "ISO, MDY"), .init(parameter: "application_name", value: ""), .init(parameter: "server_encoding", value: "UTF8"), @@ -102,8 +102,8 @@ class PSQLBackendMessageTests: XCTestCase { XCTAssertNoThrow(try embedded.writeInbound(buffer)) for expected in expectedMessages { - var message: PSQLBackendMessage? - XCTAssertNoThrow(message = try embedded.readInbound(as: PSQLBackendMessage.self)) + var message: PostgresBackendMessage? + XCTAssertNoThrow(message = try embedded.readInbound(as: PostgresBackendMessage.self)) XCTAssertEqual(message, expected) } } @@ -114,7 +114,7 @@ class PSQLBackendMessageTests: XCTestCase { // we test a NoticeResponse messages after the SSLUnupported message, since NoticeResponse // is identified by a "N" - let fields: [PSQLBackendMessage.Field : String] = [ + let fields: [PostgresBackendMessage.Field : String] = [ .file: "auth.c", .routine: "auth_failed", .line: "334", @@ -124,7 +124,7 @@ class PSQLBackendMessageTests: XCTestCase { .message: "password authentication failed for user \"postgre3\"", ] - let expectedMessages: [PSQLBackendMessage] = [ + let expectedMessages: [PostgresBackendMessage] = [ .sslUnsupported, .notice(.init(fields: fields)) ] @@ -142,14 +142,14 @@ class PSQLBackendMessageTests: XCTestCase { XCTAssertNoThrow(try embedded.writeInbound(buffer)) for expected in expectedMessages { - var message: PSQLBackendMessage? - XCTAssertNoThrow(message = try embedded.readInbound(as: PSQLBackendMessage.self)) + var message: PostgresBackendMessage? + XCTAssertNoThrow(message = try embedded.readInbound(as: PostgresBackendMessage.self)) XCTAssertEqual(message, expected) } } func testPayloadsWithoutAssociatedValues() { - let messageIDs: [PSQLBackendMessage.ID] = [ + let messageIDs: [PostgresBackendMessage.ID] = [ .bindComplete, .closeComplete, .emptyQueryResponse, @@ -163,7 +163,7 @@ class PSQLBackendMessageTests: XCTestCase { buffer.writeBackendMessage(id: messageID) { _ in } } - let expected: [PSQLBackendMessage] = [ + let expected: [PostgresBackendMessage] = [ .bindComplete, .closeComplete, .emptyQueryResponse, @@ -178,7 +178,7 @@ class PSQLBackendMessageTests: XCTestCase { } func testPayloadsWithoutAssociatedValuesInvalidLength() { - let messageIDs: [PSQLBackendMessage.ID] = [ + let messageIDs: [PostgresBackendMessage.ID] = [ .bindComplete, .closeComplete, .emptyQueryResponse, @@ -202,7 +202,7 @@ class PSQLBackendMessageTests: XCTestCase { } func testDecodeCommandCompleteMessage() { - let expected: [PSQLBackendMessage] = [ + let expected: [PostgresBackendMessage] = [ .commandComplete("SELECT 100"), .commandComplete("INSERT 0 1"), .commandComplete("UPDATE 1"), @@ -256,40 +256,40 @@ class PSQLBackendMessageTests: XCTestCase { } func testDebugDescription() { - XCTAssertEqual("\(PSQLBackendMessage.authentication(.ok))", ".authentication(.ok)") - XCTAssertEqual("\(PSQLBackendMessage.authentication(.kerberosV5))", + XCTAssertEqual("\(PostgresBackendMessage.authentication(.ok))", ".authentication(.ok)") + XCTAssertEqual("\(PostgresBackendMessage.authentication(.kerberosV5))", ".authentication(.kerberosV5)") - XCTAssertEqual("\(PSQLBackendMessage.authentication(.md5(salt: (0, 1, 2, 3))))", + XCTAssertEqual("\(PostgresBackendMessage.authentication(.md5(salt: (0, 1, 2, 3))))", ".authentication(.md5(salt: (0, 1, 2, 3)))") - XCTAssertEqual("\(PSQLBackendMessage.authentication(.plaintext))", + XCTAssertEqual("\(PostgresBackendMessage.authentication(.plaintext))", ".authentication(.plaintext)") - XCTAssertEqual("\(PSQLBackendMessage.authentication(.scmCredential))", + XCTAssertEqual("\(PostgresBackendMessage.authentication(.scmCredential))", ".authentication(.scmCredential)") - XCTAssertEqual("\(PSQLBackendMessage.authentication(.gss))", + XCTAssertEqual("\(PostgresBackendMessage.authentication(.gss))", ".authentication(.gss)") - XCTAssertEqual("\(PSQLBackendMessage.authentication(.sspi))", + XCTAssertEqual("\(PostgresBackendMessage.authentication(.sspi))", ".authentication(.sspi)") - XCTAssertEqual("\(PSQLBackendMessage.parameterStatus(.init(parameter: "foo", value: "bar")))", + XCTAssertEqual("\(PostgresBackendMessage.parameterStatus(.init(parameter: "foo", value: "bar")))", #".parameterStatus(parameter: "foo", value: "bar")"#) - XCTAssertEqual("\(PSQLBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 4567)))", + XCTAssertEqual("\(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 4567)))", ".backendKeyData(processID: 1234, secretKey: 4567)") - XCTAssertEqual("\(PSQLBackendMessage.bindComplete)", ".bindComplete") - XCTAssertEqual("\(PSQLBackendMessage.closeComplete)", ".closeComplete") - XCTAssertEqual("\(PSQLBackendMessage.commandComplete("SELECT 123"))", #".commandComplete("SELECT 123")"#) - XCTAssertEqual("\(PSQLBackendMessage.emptyQueryResponse)", ".emptyQueryResponse") - XCTAssertEqual("\(PSQLBackendMessage.noData)", ".noData") - XCTAssertEqual("\(PSQLBackendMessage.parseComplete)", ".parseComplete") - XCTAssertEqual("\(PSQLBackendMessage.portalSuspended)", ".portalSuspended") + XCTAssertEqual("\(PostgresBackendMessage.bindComplete)", ".bindComplete") + XCTAssertEqual("\(PostgresBackendMessage.closeComplete)", ".closeComplete") + XCTAssertEqual("\(PostgresBackendMessage.commandComplete("SELECT 123"))", #".commandComplete("SELECT 123")"#) + XCTAssertEqual("\(PostgresBackendMessage.emptyQueryResponse)", ".emptyQueryResponse") + XCTAssertEqual("\(PostgresBackendMessage.noData)", ".noData") + XCTAssertEqual("\(PostgresBackendMessage.parseComplete)", ".parseComplete") + XCTAssertEqual("\(PostgresBackendMessage.portalSuspended)", ".portalSuspended") - XCTAssertEqual("\(PSQLBackendMessage.readyForQuery(.idle))", ".readyForQuery(.idle)") - XCTAssertEqual("\(PSQLBackendMessage.readyForQuery(.inTransaction))", + XCTAssertEqual("\(PostgresBackendMessage.readyForQuery(.idle))", ".readyForQuery(.idle)") + XCTAssertEqual("\(PostgresBackendMessage.readyForQuery(.inTransaction))", ".readyForQuery(.inTransaction)") - XCTAssertEqual("\(PSQLBackendMessage.readyForQuery(.inFailedTransaction))", + XCTAssertEqual("\(PostgresBackendMessage.readyForQuery(.inFailedTransaction))", ".readyForQuery(.inFailedTransaction)") - XCTAssertEqual("\(PSQLBackendMessage.sslSupported)", ".sslSupported") - XCTAssertEqual("\(PSQLBackendMessage.sslUnsupported)", ".sslUnsupported") + XCTAssertEqual("\(PostgresBackendMessage.sslSupported)", ".sslSupported") + XCTAssertEqual("\(PostgresBackendMessage.sslUnsupported)", ".sslUnsupported") } } diff --git a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift index 8085c326..52e4f39c 100644 --- a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift @@ -31,9 +31,9 @@ class PSQLChannelHandlerTests: XCTestCase { XCTAssertEqual(startup.parameters.options, nil) XCTAssertEqual(startup.parameters.replication, .false) - XCTAssertNoThrow(try embedded.writeInbound(PSQLBackendMessage.authentication(.ok))) - XCTAssertNoThrow(try embedded.writeInbound(PSQLBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678)))) - XCTAssertNoThrow(try embedded.writeInbound(PSQLBackendMessage.readyForQuery(.idle))) + XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.authentication(.ok))) + XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678)))) + XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.readyForQuery(.idle))) } func testEstablishSSLCallbackIsCalledIfSSLIsSupported() { @@ -58,7 +58,7 @@ class PSQLChannelHandlerTests: XCTestCase { XCTAssertEqual(request.code, 80877103) - XCTAssertNoThrow(try embedded.writeInbound(PSQLBackendMessage.sslSupported)) + XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.sslSupported)) // a NIOSSLHandler has been added, after it SSL had been negotiated XCTAssertTrue(addSSLCallbackIsHit) @@ -99,7 +99,7 @@ class PSQLChannelHandlerTests: XCTestCase { // read the ssl request message XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .sslRequest(.init())) - XCTAssertNoThrow(try embedded.writeInbound(PSQLBackendMessage.sslUnsupported)) + XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.sslUnsupported)) // the event handler should have seen an error XCTAssertEqual(eventHandler.errors.count, 1) @@ -128,7 +128,7 @@ class PSQLChannelHandlerTests: XCTestCase { embedded.triggerUserOutboundEvent(PSQLOutgoingEvent.authenticate(authContext), promise: nil) XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .startup(.versionThree(parameters: authContext.toStartupParameters()))) - XCTAssertNoThrow(try embedded.writeInbound(PSQLBackendMessage.authentication(.md5(salt: (0,1,2,3))))) + XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.authentication(.md5(salt: (0,1,2,3))))) var message: PostgresFrontendMessage? XCTAssertNoThrow(message = try embedded.readOutbound(as: PostgresFrontendMessage.self)) @@ -157,7 +157,7 @@ class PSQLChannelHandlerTests: XCTestCase { embedded.triggerUserOutboundEvent(PSQLOutgoingEvent.authenticate(authContext), promise: nil) XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .startup(.versionThree(parameters: authContext.toStartupParameters()))) - XCTAssertNoThrow(try embedded.writeInbound(PSQLBackendMessage.authentication(.plaintext))) + XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.authentication(.plaintext))) var message: PostgresFrontendMessage? XCTAssertNoThrow(message = try embedded.readOutbound(as: PostgresFrontendMessage.self)) From d16467d507829a5827953eb2bbf4473e4ff17575 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 2 Mar 2022 23:57:43 +0100 Subject: [PATCH 075/292] Merge PSQLConnection into PostgresConnection (#240) --- .../Connection/PostgresConnection.swift | 362 ++++++++++++++++-- .../ConnectionStateMachine.swift | 43 ++- .../New/Extensions/Logging+PSQL.swift | 3 + .../PostgresNIO/New/PSQLChannelHandler.swift | 12 +- Sources/PostgresNIO/New/PSQLConnection.swift | 304 --------------- .../PostgresNIO/New/PSQLEventsHandler.swift | 10 +- .../New/PSQLPreparedStatement.swift | 2 +- .../PSQLIntegrationTests.swift | 81 ++-- Tests/IntegrationTests/PostgresNIOTests.swift | 2 +- .../New/PSQLChannelHandlerTests.swift | 6 +- .../New/PSQLConnectionTests.swift | 4 +- 11 files changed, 405 insertions(+), 424 deletions(-) delete mode 100644 Sources/PostgresNIO/New/PSQLConnection.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index be7e6c97..58eb621e 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -1,52 +1,306 @@ import NIOCore +import NIOConcurrencyHelpers import NIOSSL import Logging +import NIOPosix public final class PostgresConnection { - let underlying: PSQLConnection - + typealias ID = Int + + struct Configuration { + struct Authentication { + var username: String + var database: String? = nil + var password: String? = nil + + init(username: String, password: String?, database: String?) { + self.username = username + self.database = database + self.password = password + } + } + + struct TLS { + enum Base { + case disable + case prefer(NIOSSLContext) + case require(NIOSSLContext) + } + + var base: Base + + private init(_ base: Base) { + self.base = base + } + + static var disable: Self = Self.init(.disable) + + static func prefer(_ sslContext: NIOSSLContext) -> Self { + self.init(.prefer(sslContext)) + } + + static func require(_ sslContext: NIOSSLContext) -> Self { + self.init(.require(sslContext)) + } + } + + enum Connection { + case unresolved(host: String, port: Int) + case resolved(address: SocketAddress, serverName: String?) + } + + var connection: Connection + + /// The authentication properties to send to the Postgres server during startup auth handshake + var authentication: Authentication? + + var tls: TLS + + init(host: String, + port: Int = 5432, + username: String, + database: String? = nil, + password: String? = nil, + tls: TLS = .disable + ) { + self.connection = .unresolved(host: host, port: port) + self.authentication = Authentication(username: username, password: password, database: database) + self.tls = tls + } + + init(connection: Connection, + authentication: Authentication?, + tls: TLS + ) { + self.connection = connection + self.authentication = authentication + self.tls = tls + } + } + + /// The connection's underlying channel + /// + /// This should be private, but it is needed for `PostgresConnection` compatibility. + internal let channel: Channel + + /// The underlying `EventLoop` of both the connection and its channel. public var eventLoop: EventLoop { - return self.underlying.eventLoop + return self.channel.eventLoop } - + public var closeFuture: EventLoopFuture { - return self.underlying.channel.closeFuture + return self.channel.closeFuture } - - /// A logger to use in case - public var logger: Logger - + + /// A logger to use in case + public var logger: Logger { + get { + self._logger + } + set { + // ignore + } + } + /// A dictionary to store notification callbacks in /// /// Those are used when `PostgresConnection.addListener` is invoked. This only lives here since properties /// can not be added in extensions. All relevant code lives in `PostgresConnection+Notifications` var notificationListeners: [String: [(PostgresListenContext, (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void)]] = [:] { willSet { - self.underlying.channel.eventLoop.preconditionInEventLoop() + self.channel.eventLoop.preconditionInEventLoop() } } public var isClosed: Bool { - return !self.underlying.channel.isActive + return !self.channel.isActive } - - init(underlying: PSQLConnection, logger: Logger) { - self.underlying = underlying - self.logger = logger - - self.underlying.channel.pipeline.handler(type: PSQLChannelHandler.self).whenSuccess { handler in - handler.notificationDelegate = self + + let id: ID + + private var _logger: Logger + + init(channel: Channel, connectionID: ID, logger: Logger) { + self.channel = channel + self.id = connectionID + self._logger = logger + } + deinit { + assert(self.isClosed, "PostgresConnection deinitialized before being closed.") + } + + func start(configuration: Configuration) -> EventLoopFuture { + // 1. configure handlers + + var configureSSLCallback: ((Channel) throws -> ())? = nil + switch configuration.tls.base { + case .disable: + break + + case .prefer(let sslContext), .require(let sslContext): + configureSSLCallback = { channel in + channel.eventLoop.assertInEventLoop() + + let sslHandler = try NIOSSLClientHandler( + context: sslContext, + serverHostname: configuration.sslServerHostname + ) + try channel.pipeline.syncOperations.addHandler(sslHandler, position: .first) + } + } + + let channelHandler = PSQLChannelHandler( + configuration: configuration, + logger: logger, + configureSSLCallback: configureSSLCallback + ) + channelHandler.notificationDelegate = self + + let eventHandler = PSQLEventsHandler(logger: logger) + + // 2. add handlers + + do { + try self.channel.pipeline.syncOperations.addHandler(eventHandler) + try self.channel.pipeline.syncOperations.addHandler(channelHandler, position: .before(eventHandler)) + } catch { + return self.eventLoop.makeFailedFuture(error) + } + + let startupFuture: EventLoopFuture + if configuration.authentication == nil { + startupFuture = eventHandler.readyForStartupFuture + } else { + startupFuture = eventHandler.authenticateFuture + } + + // 3. wait for startup future to succeed. + + return startupFuture.flatMapError { error in + // in case of an startup error, the connection must be closed and after that + // the originating error should be surfaced + + self.channel.closeFuture.flatMapThrowing { _ in + throw error + } + } + } + + static func connect( + connectionID: ID, + configuration: PostgresConnection.Configuration, + logger: Logger, + on eventLoop: EventLoop + ) -> EventLoopFuture { + + var logger = logger + logger[postgresMetadataKey: .connectionID] = "\(connectionID)" + + // Here we dispatch to the `eventLoop` first before we setup the EventLoopFuture chain, to + // ensure all `flatMap`s are executed on the EventLoop (this means the enqueuing of the + // callbacks). + // + // This saves us a number of context switches between the thread the Connection is created + // on and the EventLoop. In addition, it eliminates all potential races between the creating + // thread and the EventLoop. + return eventLoop.flatSubmit { () -> EventLoopFuture in + let connectFuture: EventLoopFuture + + switch configuration.connection { + case .resolved(let address, _): + connectFuture = ClientBootstrap(group: eventLoop).connect(to: address) + case .unresolved(let host, let port): + connectFuture = ClientBootstrap(group: eventLoop).connect(host: host, port: port) + } + + return connectFuture.flatMap { channel -> EventLoopFuture in + let connection = PostgresConnection(channel: channel, connectionID: connectionID, logger: logger) + return connection.start(configuration: configuration).map { _ in connection } + }.flatMapErrorThrowing { error -> PostgresConnection in + switch error { + case is PSQLError: + throw error + default: + throw PSQLError.channel(underlying: error) + } + } } } - + + // MARK: Query + + func query(_ query: PostgresQuery, logger: Logger) -> EventLoopFuture { + var logger = logger + logger[postgresMetadataKey: .connectionID] = "\(self.id)" + guard query.binds.count <= Int(Int16.max) else { + return self.channel.eventLoop.makeFailedFuture(PSQLError.tooManyParameters) + } + + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let context = ExtendedQueryContext( + query: query, + logger: logger, + promise: promise) + + self.channel.write(PSQLTask.extendedQuery(context), promise: nil) + + return promise.futureResult + } + + // MARK: Prepared statements + + func prepareStatement(_ query: String, with name: String, logger: Logger) -> EventLoopFuture { + let promise = self.channel.eventLoop.makePromise(of: RowDescription?.self) + let context = PrepareStatementContext( + name: name, + query: query, + logger: logger, + promise: promise) + + self.channel.write(PSQLTask.preparedStatement(context), promise: nil) + return promise.futureResult.map { rowDescription in + PSQLPreparedStatement(name: name, query: query, connection: self, rowDescription: rowDescription) + } + } + + func execute(_ executeStatement: PSQLExecuteStatement, logger: Logger) -> EventLoopFuture { + guard executeStatement.binds.count <= Int(Int16.max) else { + return self.channel.eventLoop.makeFailedFuture(PSQLError.tooManyParameters) + } + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let context = ExtendedQueryContext( + executeStatement: executeStatement, + logger: logger, + promise: promise) + + self.channel.write(PSQLTask.extendedQuery(context), promise: nil) + return promise.futureResult + } + + func close(_ target: CloseTarget, logger: Logger) -> EventLoopFuture { + let promise = self.channel.eventLoop.makePromise(of: Void.self) + let context = CloseCommandContext(target: target, logger: logger, promise: promise) + + self.channel.write(PSQLTask.closeCommand(context), promise: nil) + return promise.futureResult + } + + public func close() -> EventLoopFuture { - return self.underlying.close() + guard !self.isClosed else { + return self.eventLoop.makeSucceededFuture(()) + } + + self.channel.close(mode: .all, promise: nil) + return self.closeFuture } } // MARK: Connect extension PostgresConnection { + static let idGenerator = NIOAtomic.makeAtomic(value: 0) + public static func connect( to socketAddress: SocketAddress, tlsConfiguration: TLSConfiguration? = nil, @@ -54,30 +308,29 @@ extension PostgresConnection { logger: Logger = .init(label: "codes.vapor.postgres"), on eventLoop: EventLoop ) -> EventLoopFuture { - var tlsFuture: EventLoopFuture + var tlsFuture: EventLoopFuture if let tlsConfiguration = tlsConfiguration { tlsFuture = eventLoop.makeSucceededVoidFuture().flatMapBlocking(onto: .global(qos: .default)) { - try PSQLConnection.Configuration.TLS.require(.init(configuration: tlsConfiguration)) + try PostgresConnection.Configuration.TLS.require(.init(configuration: tlsConfiguration)) } } else { tlsFuture = eventLoop.makeSucceededFuture(.disable) } return tlsFuture.flatMap { tls in - let configuration = PSQLConnection.Configuration( + let configuration = PostgresConnection.Configuration( connection: .resolved(address: socketAddress, serverName: serverHostname), authentication: nil, tls: tls ) - return PSQLConnection.connect( + return PostgresConnection.connect( + connectionID: idGenerator.add(1), configuration: configuration, logger: logger, on: eventLoop ) - }.map { connection in - PostgresConnection(underlying: connection, logger: logger) }.flatMapErrorThrowing { error in throw error.asAppropriatePostgresError } @@ -94,9 +347,9 @@ extension PostgresConnection { password: password, database: database) let outgoing = PSQLOutgoingEvent.authenticate(authContext) - self.underlying.channel.triggerUserOutboundEvent(outgoing, promise: nil) + self.channel.triggerUserOutboundEvent(outgoing, promise: nil) - return self.underlying.channel.pipeline.handler(type: PSQLEventsHandler.self).flatMap { handler in + return self.channel.pipeline.handler(type: PSQLEventsHandler.self).flatMap { handler in handler.authenticateFuture }.flatMapErrorThrowing { error in throw error.asAppropriatePostgresError @@ -112,19 +365,19 @@ extension PostgresConnection { func query(_ query: PostgresQuery, logger: Logger, file: String = #file, line: UInt = #line) async throws -> PostgresRowSequence { var logger = logger - logger[postgresMetadataKey: .connectionID] = "\(self.underlying.connectionID)" + logger[postgresMetadataKey: .connectionID] = "\(self.id)" do { guard query.binds.count <= Int(Int16.max) else { throw PSQLError.tooManyParameters } - let promise = self.underlying.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let context = ExtendedQueryContext( query: query, logger: logger, promise: promise) - self.underlying.channel.write(PSQLTask.extendedQuery(context), promise: nil) + self.channel.write(PSQLTask.extendedQuery(context), promise: nil) return try await promise.futureResult.map({ $0.asyncSequence() }).get() } @@ -147,21 +400,21 @@ extension PostgresConnection: PostgresDatabase { switch command { case .query(let query, let onMetadata, let onRow): - resultFuture = self.underlying.query(query, logger: logger).flatMap { stream in + resultFuture = self.query(query, logger: logger).flatMap { stream in return stream.onRow(onRow).map { _ in onMetadata(PostgresQueryMetadata(string: stream.commandTag)!) } } case .queryAll(let query, let onResult): - resultFuture = self.underlying.query(query, logger: logger).flatMap { rows in + resultFuture = self.query(query, logger: logger).flatMap { rows in return rows.all().map { allrows in onResult(.init(metadata: PostgresQueryMetadata(string: rows.commandTag)!, rows: allrows)) } } case .prepareQuery(let request): - resultFuture = self.underlying.prepareStatement(request.query, with: request.name, logger: self.logger).map { + resultFuture = self.prepareStatement(request.query, with: request.name, logger: self.logger).map { request.prepared = PreparedQuery(underlying: $0, database: self) } @@ -175,7 +428,7 @@ extension PostgresConnection: PostgresDatabase { rowDescription: preparedQuery.underlying.rowDescription ) - resultFuture = self.underlying.execute(statement, logger: logger).flatMap { rows in + resultFuture = self.execute(statement, logger: logger).flatMap { rows in return rows.onRow(onRow) } } @@ -231,7 +484,7 @@ extension PostgresConnection { let listenContext = PostgresListenContext() - self.underlying.channel.pipeline.handler(type: PSQLChannelHandler.self).whenSuccess { handler in + self.channel.pipeline.handler(type: PSQLChannelHandler.self).whenSuccess { handler in if self.notificationListeners[channel] != nil { self.notificationListeners[channel]!.append((listenContext, notificationHandler)) } @@ -244,7 +497,7 @@ extension PostgresConnection { // self is weak, since the connection can long be gone, when the listeners stop is // triggered. listenContext must be weak to prevent a retain cycle - self?.underlying.channel.eventLoop.execute { + self?.channel.eventLoop.execute { guard let self = self, // the connection is already gone var listeners = self.notificationListeners[channel] // we don't have the listeners for this topic ¯\_(ツ)_/¯ @@ -264,7 +517,7 @@ extension PostgresConnection { extension PostgresConnection: PSQLChannelHandlerNotificationDelegate { func notificationReceived(_ notification: PostgresBackendMessage.NotificationResponse) { - self.underlying.eventLoop.assertInEventLoop() + self.eventLoop.assertInEventLoop() guard let listeners = self.notificationListeners[notification.channel] else { return @@ -280,3 +533,36 @@ extension PostgresConnection: PSQLChannelHandlerNotificationDelegate { } } } + +enum CloseTarget { + case preparedStatement(String) + case portal(String) +} + +extension PostgresConnection.Configuration { + var sslServerHostname: String? { + switch self.connection { + case .unresolved(let host, _): + guard !host.isIPAddress() else { + return nil + } + return host + case .resolved(_, let serverName): + return serverName + } + } +} + +// copy and pasted from NIOSSL: +private extension String { + func isIPAddress() -> Bool { + // We need some scratch space to let inet_pton write into. + var ipv4Addr = in_addr() + var ipv6Addr = in6_addr() + + return self.withCString { ptr in + return inet_pton(AF_INET, ptr, &ipv4Addr) == 1 || + inet_pton(AF_INET6, ptr, &ipv6Addr) == 1 + } + } +} diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 4a1a2813..fa00328b 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -134,22 +134,37 @@ struct ConnectionStateMachine { } mutating func connected(tls: TLSConfiguration) -> ConnectionAction { - guard case .initialized = self.state else { - preconditionFailure("Unexpected state") - } + switch self.state { + case .initialized: + switch tls { + case .disable: + self.state = .waitingToStartAuthentication + return .provideAuthenticationContext - switch tls { - case .disable: - self.state = .waitingToStartAuthentication - return .provideAuthenticationContext + case .prefer: + self.state = .sslRequestSent(.prefer) + return .sendSSLRequest - case .prefer: - self.state = .sslRequestSent(.prefer) - return .sendSSLRequest + case .require: + self.state = .sslRequestSent(.require) + return .sendSSLRequest + } - case .require: - self.state = .sslRequestSent(.require) - return .sendSSLRequest + case .sslRequestSent, + .sslNegotiated, + .sslHandlerAdded, + .waitingToStartAuthentication, + .authenticating, + .authenticated, + .readyForQuery, + .extendedQuery, + .prepareStatement, + .closeCommand, + .error, + .closing, + .closed, + .modifying: + return .wait } } @@ -1084,7 +1099,7 @@ extension ConnectionStateMachine { case .tooManyParameters: return true case .connectionQuiescing: - preconditionFailure("Pure client error, that is thrown directly in PSQLConnection") + preconditionFailure("Pure client error, that is thrown directly in PostgresConnection") case .connectionClosed: preconditionFailure("Pure client error, that is thrown directly and should never ") case .connectionError: diff --git a/Sources/PostgresNIO/New/Extensions/Logging+PSQL.swift b/Sources/PostgresNIO/New/Extensions/Logging+PSQL.swift index 90e91177..ed83e84d 100644 --- a/Sources/PostgresNIO/New/Extensions/Logging+PSQL.swift +++ b/Sources/PostgresNIO/New/Extensions/Logging+PSQL.swift @@ -1,5 +1,8 @@ import Logging +@usableFromInline +enum PSQLConnection {} + extension PSQLConnection { @usableFromInline enum LoggerMetaDataKey: String { diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PSQLChannelHandler.swift index c39537d6..812cd358 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PSQLChannelHandler.swift @@ -26,13 +26,13 @@ final class PSQLChannelHandler: ChannelDuplexHandler { private var rowStream: PSQLRowStream? private var decoder: NIOSingleStepByteToMessageProcessor private var encoder: BufferedMessageEncoder! - private let configuration: PSQLConnection.Configuration + private let configuration: PostgresConnection.Configuration private let configureSSLCallback: ((Channel) throws -> Void)? /// this delegate should only be accessed on the connections `EventLoop` weak var notificationDelegate: PSQLChannelHandlerNotificationDelegate? - init(configuration: PSQLConnection.Configuration, + init(configuration: PostgresConnection.Configuration, logger: Logger, configureSSLCallback: ((Channel) throws -> Void)?) { @@ -45,7 +45,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { #if DEBUG /// for testing purposes only - init(configuration: PSQLConnection.Configuration, + init(configuration: PostgresConnection.Configuration, state: ConnectionStateMachine = .init(.initialized), logger: Logger = .psqlNoOpLogger, configureSSLCallback: ((Channel) throws -> Void)?) @@ -518,7 +518,7 @@ extension PSQLChannelHandler: PSQLRowsDataSource { } } -extension PSQLConnection.Configuration.Authentication { +extension PostgresConnection.Configuration.Authentication { func toAuthContext() -> AuthContext { AuthContext( username: self.username, @@ -575,7 +575,7 @@ private extension Insecure.MD5.Digest { } extension ConnectionStateMachine.TLSConfiguration { - fileprivate init(_ connection: PSQLConnection.Configuration.TLS) { + fileprivate init(_ connection: PostgresConnection.Configuration.TLS) { switch connection.base { case .disable: self = .disable @@ -589,7 +589,7 @@ extension ConnectionStateMachine.TLSConfiguration { extension PSQLChannelHandler { convenience init( - configuration: PSQLConnection.Configuration, + configuration: PostgresConnection.Configuration, configureSSLCallback: ((Channel) throws -> Void)?) { self.init( diff --git a/Sources/PostgresNIO/New/PSQLConnection.swift b/Sources/PostgresNIO/New/PSQLConnection.swift deleted file mode 100644 index 0b1ce1ab..00000000 --- a/Sources/PostgresNIO/New/PSQLConnection.swift +++ /dev/null @@ -1,304 +0,0 @@ -import NIOCore -import NIOPosix -import NIOFoundationCompat -import NIOSSL -import class Foundation.JSONEncoder -import class Foundation.JSONDecoder -import struct Foundation.UUID -import Logging - -@usableFromInline -final class PSQLConnection { - - struct Configuration { - - struct Authentication { - var username: String - var database: String? = nil - var password: String? = nil - - init(username: String, password: String?, database: String?) { - self.username = username - self.database = database - self.password = password - } - } - - struct TLS { - enum Base { - case disable - case prefer(NIOSSLContext) - case require(NIOSSLContext) - } - - var base: Base - - private init(_ base: Base) { - self.base = base - } - - static var disable: Self = Self.init(.disable) - - static func prefer(_ sslContext: NIOSSLContext) -> Self { - self.init(.prefer(sslContext)) - } - - static func require(_ sslContext: NIOSSLContext) -> Self { - self.init(.require(sslContext)) - } - } - - enum Connection { - case unresolved(host: String, port: Int) - case resolved(address: SocketAddress, serverName: String?) - } - - var connection: Connection - - /// The authentication properties to send to the Postgres server during startup auth handshake - var authentication: Authentication? - - var tls: TLS - - init(host: String, - port: Int = 5432, - username: String, - database: String? = nil, - password: String? = nil, - tls: TLS = .disable - ) { - self.connection = .unresolved(host: host, port: port) - self.authentication = Authentication(username: username, password: password, database: database) - self.tls = tls - } - - init(connection: Connection, - authentication: Authentication?, - tls: TLS - ) { - self.connection = connection - self.authentication = authentication - self.tls = tls - } - } - - /// The connection's underlying channel - /// - /// This should be private, but it is needed for `PostgresConnection` compatibility. - internal let channel: Channel - - /// The underlying `EventLoop` of both the connection and its channel. - var eventLoop: EventLoop { - return self.channel.eventLoop - } - - var closeFuture: EventLoopFuture { - return self.channel.closeFuture - } - - var isClosed: Bool { - return !self.channel.isActive - } - - /// A logger to use in case - private var logger: Logger - let connectionID: String - - init(channel: Channel, connectionID: String, logger: Logger) { - self.channel = channel - self.connectionID = connectionID - self.logger = logger - } - deinit { - assert(self.isClosed, "PostgresConnection deinitialized before being closed.") - } - - func close() -> EventLoopFuture { - guard !self.isClosed else { - return self.eventLoop.makeSucceededFuture(()) - } - - self.channel.close(mode: .all, promise: nil) - return self.closeFuture - } - - // MARK: Query - - func query(_ query: PostgresQuery, logger: Logger) -> EventLoopFuture { - var logger = logger - logger[postgresMetadataKey: .connectionID] = "\(self.connectionID)" - guard query.binds.count <= Int(Int16.max) else { - return self.channel.eventLoop.makeFailedFuture(PSQLError.tooManyParameters) - } - - let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) - let context = ExtendedQueryContext( - query: query, - logger: logger, - promise: promise) - - self.channel.write(PSQLTask.extendedQuery(context), promise: nil) - - return promise.futureResult - } - - // MARK: Prepared statements - - func prepareStatement(_ query: String, with name: String, logger: Logger) -> EventLoopFuture { - let promise = self.channel.eventLoop.makePromise(of: RowDescription?.self) - let context = PrepareStatementContext( - name: name, - query: query, - logger: logger, - promise: promise) - - self.channel.write(PSQLTask.preparedStatement(context), promise: nil) - return promise.futureResult.map { rowDescription in - PSQLPreparedStatement(name: name, query: query, connection: self, rowDescription: rowDescription) - } - } - - func execute(_ executeStatement: PSQLExecuteStatement, logger: Logger) -> EventLoopFuture { - guard executeStatement.binds.count <= Int(Int16.max) else { - return self.channel.eventLoop.makeFailedFuture(PSQLError.tooManyParameters) - } - let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) - let context = ExtendedQueryContext( - executeStatement: executeStatement, - logger: logger, - promise: promise) - - self.channel.write(PSQLTask.extendedQuery(context), promise: nil) - return promise.futureResult - } - - func close(_ target: CloseTarget, logger: Logger) -> EventLoopFuture { - let promise = self.channel.eventLoop.makePromise(of: Void.self) - let context = CloseCommandContext(target: target, logger: logger, promise: promise) - - self.channel.write(PSQLTask.closeCommand(context), promise: nil) - return promise.futureResult - } - - static func connect( - configuration: PSQLConnection.Configuration, - logger: Logger, - on eventLoop: EventLoop - ) -> EventLoopFuture { - - let connectionID = UUID().uuidString - var logger = logger - logger[postgresMetadataKey: .connectionID] = "\(connectionID)" - - // Here we dispatch to the `eventLoop` first before we setup the EventLoopFuture chain, to - // ensure all `flatMap`s are executed on the EventLoop (this means the enqueuing of the - // callbacks). - // - // This saves us a number of context switches between the thread the Connection is created - // on and the EventLoop. In addition, it eliminates all potential races between the creating - // thread and the EventLoop. - return eventLoop.flatSubmit { - eventLoop.submit { () throws -> SocketAddress in - switch configuration.connection { - case .resolved(let address, _): - return address - case .unresolved(let host, let port): - return try SocketAddress.makeAddressResolvingHost(host, port: port) - } - }.flatMap { address -> EventLoopFuture in - let bootstrap = ClientBootstrap(group: eventLoop) - .channelInitializer { channel in - var configureSSLCallback: ((Channel) throws -> ())? = nil - - switch configuration.tls.base { - case .disable: - break - - case .prefer(let sslContext), .require(let sslContext): - configureSSLCallback = { channel in - channel.eventLoop.assertInEventLoop() - - let sslHandler = try NIOSSLClientHandler( - context: sslContext, - serverHostname: configuration.sslServerHostname - ) - try channel.pipeline.syncOperations.addHandler(sslHandler, position: .first) - } - } - - return channel.pipeline.addHandlers([ - PSQLChannelHandler( - configuration: configuration, - logger: logger, - configureSSLCallback: configureSSLCallback), - PSQLEventsHandler(logger: logger) - ]) - } - - return bootstrap.connect(to: address) - }.flatMap { channel -> EventLoopFuture in - channel.pipeline.handler(type: PSQLEventsHandler.self).flatMap { - eventHandler -> EventLoopFuture in - - let startupFuture: EventLoopFuture - if configuration.authentication == nil { - startupFuture = eventHandler.readyForStartupFuture - } else { - startupFuture = eventHandler.authenticateFuture - } - - return startupFuture.flatMapError { error in - // in case of an startup error, the connection must be closed and after that - // the originating error should be surfaced - - channel.closeFuture.flatMapThrowing { _ in - throw error - } - } - }.map { _ in channel } - }.map { channel in - PSQLConnection(channel: channel, connectionID: connectionID, logger: logger) - }.flatMapErrorThrowing { error -> PSQLConnection in - switch error { - case is PSQLError: - throw error - default: - throw PSQLError.channel(underlying: error) - } - } - } - } -} - -enum CloseTarget { - case preparedStatement(String) - case portal(String) -} - -extension PSQLConnection.Configuration { - var sslServerHostname: String? { - switch self.connection { - case .unresolved(let host, _): - guard !host.isIPAddress() else { - return nil - } - return host - case .resolved(_, let serverName): - return serverName - } - } -} - -// copy and pasted from NIOSSL: -private extension String { - func isIPAddress() -> Bool { - // We need some scratch space to let inet_pton write into. - var ipv4Addr = in_addr() - var ipv6Addr = in6_addr() - - return self.withCString { ptr in - return inet_pton(AF_INET, ptr, &ipv4Addr) == 1 || - inet_pton(AF_INET6, ptr, &ipv6Addr) == 1 - } - } -} diff --git a/Sources/PostgresNIO/New/PSQLEventsHandler.swift b/Sources/PostgresNIO/New/PSQLEventsHandler.swift index 2c9aeaa1..0318061e 100644 --- a/Sources/PostgresNIO/New/PSQLEventsHandler.swift +++ b/Sources/PostgresNIO/New/PSQLEventsHandler.swift @@ -74,14 +74,18 @@ final class PSQLEventsHandler: ChannelInboundHandler { } func handlerAdded(context: ChannelHandlerContext) { - precondition(!context.channel.isActive) - self.readyForStartupPromise = context.eventLoop.makePromise(of: Void.self) self.authenticatePromise = context.eventLoop.makePromise(of: Void.self) + + if context.channel.isActive, case .initialized = self.state { + self.state = .connected + } } func channelActive(context: ChannelHandlerContext) { - self.state = .connected + if case .initialized = self.state { + self.state = .connected + } context.fireChannelActive() } diff --git a/Sources/PostgresNIO/New/PSQLPreparedStatement.swift b/Sources/PostgresNIO/New/PSQLPreparedStatement.swift index fbdfd868..5a9abf7e 100644 --- a/Sources/PostgresNIO/New/PSQLPreparedStatement.swift +++ b/Sources/PostgresNIO/New/PSQLPreparedStatement.swift @@ -7,7 +7,7 @@ struct PSQLPreparedStatement { let query: String /// The postgres connection the statement was prepared on - let connection: PSQLConnection + let connection: PostgresConnection /// The `RowDescription` to apply to all `DataRow`s when executing this `PSQLPreparedStatement` let rowDescription: RowDescription? diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 6dce981c..38e41a20 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -12,8 +12,8 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - var conn: PSQLConnection? - XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) XCTAssertNoThrow(try conn?.close().wait()) } @@ -22,7 +22,7 @@ final class IntegrationTests: XCTestCase { // authentication failure. try XCTSkipIf(env("POSTGRES_HOST_AUTH_METHOD") == "trust") - let config = PSQLConnection.Configuration( + let config = PostgresConnection.Configuration( host: env("POSTGRES_HOSTNAME") ?? "localhost", port: 5432, username: env("POSTGRES_USER") ?? "test_username", @@ -36,8 +36,8 @@ final class IntegrationTests: XCTestCase { var logger = Logger.psqlTest logger.logLevel = .info - var connection: PSQLConnection? - XCTAssertThrowsError(connection = try PSQLConnection.connect(configuration: config, logger: logger, on: eventLoopGroup.next()).wait()) { + var connection: PostgresConnection? + XCTAssertThrowsError(connection = try PostgresConnection.connect(connectionID: 1, configuration: config, logger: logger, on: eventLoopGroup.next()).wait()) { XCTAssertTrue($0 is PSQLError) } @@ -50,8 +50,8 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - var conn: PSQLConnection? - XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } var stream: PSQLRowStream? @@ -68,8 +68,8 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - var conn: PSQLConnection? - XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } var stream: PSQLRowStream? @@ -96,8 +96,8 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - var conn: PSQLConnection? - XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } for _ in 0..<1_000 { @@ -116,8 +116,8 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - var conn: PSQLConnection? - XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } var stream: PSQLRowStream? @@ -134,8 +134,8 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - var conn: PSQLConnection? - XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } var stream: PSQLRowStream? @@ -176,8 +176,8 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - var conn: PSQLConnection? - XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } var stream: PSQLRowStream? @@ -194,8 +194,8 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - var conn: PSQLConnection? - XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } var stream: PSQLRowStream? @@ -212,8 +212,8 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - var conn: PSQLConnection? - XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } var stream: PSQLRowStream? @@ -230,8 +230,8 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - var conn: PSQLConnection? - XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } var stream: PSQLRowStream? @@ -259,8 +259,8 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - var conn: PSQLConnection? - XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } var stream: PSQLRowStream? @@ -285,8 +285,8 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - var conn: PSQLConnection? - XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } var stream: PSQLRowStream? @@ -310,8 +310,8 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - var conn: PSQLConnection? - XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } do { @@ -345,26 +345,3 @@ final class IntegrationTests: XCTestCase { } } } - - -extension PSQLConnection { - - static func test(on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { - var logger = Logger(label: "psql.connection.test") - logger.logLevel = logLevel - let config = PSQLConnection.Configuration( - host: env("POSTGRES_HOSTNAME") ?? "localhost", - port: 5432, - username: env("POSTGRES_USER") ?? "test_username", - database: env("POSTGRES_DB") ?? "test_database", - password: env("POSTGRES_PASSWORD") ?? "test_password", - tls: .disable - ) - - return PSQLConnection.connect(configuration: config, logger: logger, on: eventLoop) - } -} - -extension PostgresDecodingContext where JSONDecoder == Foundation.JSONDecoder { - static let `default`: Self = PostgresDecodingContext(jsonDecoder: JSONDecoder()) -} diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index 7be9bab7..ee7ecaf0 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -1064,7 +1064,7 @@ final class PostgresNIOTests: XCTestCase { func testRemoteClose() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) - XCTAssertNoThrow( try conn?.underlying.channel.close().wait() ) + XCTAssertNoThrow( try conn?.channel.close().wait() ) } // https://github.com/vapor/postgres-nio/issues/113 diff --git a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift index 52e4f39c..01b830c4 100644 --- a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift @@ -173,9 +173,9 @@ class PSQLChannelHandlerTests: XCTestCase { username: String = "test", database: String = "postgres", password: String = "password", - tls: PSQLConnection.Configuration.TLS = .disable - ) -> PSQLConnection.Configuration { - PSQLConnection.Configuration( + tls: PostgresConnection.Configuration.TLS = .disable + ) -> PostgresConnection.Configuration { + PostgresConnection.Configuration( host: host, port: port, username: username, diff --git a/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift b/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift index a0b68cea..260705c2 100644 --- a/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift @@ -21,7 +21,7 @@ class PSQLConnectionTests: XCTestCase { return XCTFail("Could not get port number from temp started server") } - let config = PSQLConnection.Configuration( + let config = PostgresConnection.Configuration( host: "127.0.0.1", port: port, username: "postgres", @@ -33,7 +33,7 @@ class PSQLConnectionTests: XCTestCase { var logger = Logger.psqlTest logger.logLevel = .trace - XCTAssertThrowsError(try PSQLConnection.connect(configuration: config, logger: logger, on: eventLoopGroup.next()).wait()) { + XCTAssertThrowsError(try PostgresConnection.connect(connectionID: 1, configuration: config, logger: logger, on: eventLoopGroup.next()).wait()) { XCTAssertTrue($0 is PSQLError) } } From ef425af4833a36b7b153119a7eba67751d5cd76c Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 5 Mar 2022 21:10:34 +0100 Subject: [PATCH 076/292] Cleanup PostgresDecodable (#241) --- .../New/Data/Array+PostgresCodable.swift | 11 +- .../New/Data/Bool+PostgresCodable.swift | 40 +++-- .../New/Data/Bytes+PostgresCodable.swift | 24 ++- .../New/Data/Date+PostgresCodable.swift | 49 +++--- .../New/Data/Decimal+PostgresCodable.swift | 40 +++-- .../New/Data/Float+PostgresCodable.swift | 59 ++++--- .../New/Data/Int+PostgresCodable.swift | 160 +++++++++++------- .../New/Data/JSON+PostgresCodable.swift | 30 ++-- .../RawRepresentable+PostgresCodable.swift | 30 ++-- .../New/Data/String+PostgresCodable.swift | 18 +- .../New/Data/UUID+PostgresCodable.swift | 19 ++- Sources/PostgresNIO/New/PostgresCodable.swift | 17 +- .../New/Data/Array+PSQLCodableTests.swift | 18 +- .../New/Data/Bool+PSQLCodableTests.swift | 14 +- .../New/Data/Bytes+PSQLCodableTests.swift | 4 +- .../New/Data/Date+PSQLCodableTests.swift | 18 +- .../New/Data/Decimal+PSQLCodableTests.swift | 4 +- .../New/Data/Float+PSQLCodableTests.swift | 24 +-- .../New/Data/JSON+PSQLCodableTests.swift | 10 +- .../RawRepresentable+PSQLCodableTests.swift | 6 +- .../New/Data/String+PSQLCodableTests.swift | 8 +- .../New/Data/UUID+PSQLCodableTests.swift | 12 +- 22 files changed, 346 insertions(+), 269 deletions(-) diff --git a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift index 875361e1..91edc9a1 100644 --- a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift @@ -104,12 +104,12 @@ extension Array: PostgresEncodable where Element: PSQLArrayElement { } extension Array: PostgresDecodable where Element: PSQLArrayElement { - static func decode( + init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext - ) throws -> Array { + ) throws { guard case .binary = format else { // currently we only support decoding arrays in binary format. throw PostgresCastingError.Code.failure @@ -124,7 +124,8 @@ extension Array: PostgresDecodable where Element: PSQLArrayElement { let elementType = PostgresDataType(element) guard isNotEmpty == 1 else { - return [] + self = [] + return } guard let (expectedArrayCount, dimensions) = buffer.readMultipleIntegers(endianness: .big, as: (Int32, Int32).self), @@ -146,12 +147,12 @@ extension Array: PostgresDecodable where Element: PSQLArrayElement { throw PostgresCastingError.Code.failure } - let element = try Element.decode(from: &elementBuffer, type: elementType, format: format, context: context) + let element = try Element.init(from: &elementBuffer, type: elementType, format: format, context: context) result.append(element) } - return result + self = result } } diff --git a/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift index 9d9120b8..88609d13 100644 --- a/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift @@ -1,35 +1,27 @@ import NIOCore -extension Bool: PostgresCodable { - var psqlType: PostgresDataType { - .bool - } - - var psqlFormat: PostgresFormat { - .binary - } - - static func decode( +extension Bool: PostgresDecodable { + init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext - ) throws -> Self { + ) throws { guard type == .bool else { throw PostgresCastingError.Code.typeMismatch } - + switch format { case .binary: guard buffer.readableBytes == 1 else { throw PostgresCastingError.Code.failure } - + switch buffer.readInteger(as: UInt8.self) { case .some(0): - return false + self = false case .some(1): - return true + self = true default: throw PostgresCastingError.Code.failure } @@ -37,17 +29,27 @@ extension Bool: PostgresCodable { guard buffer.readableBytes == 1 else { throw PostgresCastingError.Code.failure } - + switch buffer.readInteger(as: UInt8.self) { case .some(UInt8(ascii: "f")): - return false + self = false case .some(UInt8(ascii: "t")): - return true + self = true default: throw PostgresCastingError.Code.failure } } } +} + +extension Bool: PostgresEncodable { + var psqlType: PostgresDataType { + .bool + } + + var psqlFormat: PostgresFormat { + .binary + } func encode( into byteBuffer: inout ByteBuffer, @@ -56,3 +58,5 @@ extension Bool: PostgresCodable { byteBuffer.writeInteger(self ? 1 : 0, as: UInt8.self) } } + +extension Bool: PostgresCodable {} diff --git a/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift index 1c98948f..168d9c69 100644 --- a/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift @@ -19,7 +19,7 @@ extension PostgresEncodable where Self: Sequence, Self.Element == UInt8 { } } -extension ByteBuffer: PostgresCodable { +extension ByteBuffer: PostgresEncodable { var psqlType: PostgresDataType { .bytea } @@ -35,18 +35,22 @@ extension ByteBuffer: PostgresCodable { var copyOfSelf = self // dirty hack byteBuffer.writeBuffer(©OfSelf) } +} - static func decode( +extension ByteBuffer: PostgresDecodable { + init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext - ) throws -> Self { - return buffer + ) { + self = buffer } } -extension Data: PostgresCodable { +extension ByteBuffer: PostgresCodable {} + +extension Data: PostgresEncodable { var psqlType: PostgresDataType { .bytea } @@ -61,13 +65,17 @@ extension Data: PostgresCodable { ) { byteBuffer.writeBytes(self) } +} - static func decode( +extension Data: PostgresDecodable { + init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext - ) throws -> Self { - return buffer.readData(length: buffer.readableBytes, byteTransferStrategy: .automatic)! + ) { + self = buffer.readData(length: buffer.readableBytes, byteTransferStrategy: .automatic)! } } + +extension Data: PostgresCodable {} diff --git a/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift index cb440367..8c164f1c 100644 --- a/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift @@ -1,7 +1,7 @@ import NIOCore import struct Foundation.Date -extension Date: PostgresCodable { +extension Date: PostgresEncodable { var psqlType: PostgresDataType { .timestamptz } @@ -10,44 +10,47 @@ extension Date: PostgresCodable { .binary } - static func decode( + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + let seconds = self.timeIntervalSince(Self._psqlDateStart) * Double(Self._microsecondsPerSecond) + byteBuffer.writeInteger(Int64(seconds)) + } + + // MARK: Private Constants + + private static let _microsecondsPerSecond: Int64 = 1_000_000 + private static let _secondsInDay: Int64 = 24 * 60 * 60 + + /// values are stored as seconds before or after midnight 2000-01-01 + private static let _psqlDateStart = Date(timeIntervalSince1970: 946_684_800) +} + +extension Date: PostgresDecodable { + init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext - ) throws -> Self { + ) throws { switch type { case .timestamp, .timestamptz: guard buffer.readableBytes == 8, let microseconds = buffer.readInteger(as: Int64.self) else { throw PostgresCastingError.Code.failure } - let seconds = Double(microseconds) / Double(_microsecondsPerSecond) - return Date(timeInterval: seconds, since: _psqlDateStart) + let seconds = Double(microseconds) / Double(Self._microsecondsPerSecond) + self = Date(timeInterval: seconds, since: Self._psqlDateStart) case .date: guard buffer.readableBytes == 4, let days = buffer.readInteger(as: Int32.self) else { throw PostgresCastingError.Code.failure } - let seconds = Int64(days) * _secondsInDay - return Date(timeInterval: Double(seconds), since: _psqlDateStart) + let seconds = Int64(days) * Self._secondsInDay + self = Date(timeInterval: Double(seconds), since: Self._psqlDateStart) default: throw PostgresCastingError.Code.typeMismatch } } - - func encode( - into byteBuffer: inout ByteBuffer, - context: PostgresEncodingContext - ) { - let seconds = self.timeIntervalSince(Self._psqlDateStart) * Double(Self._microsecondsPerSecond) - byteBuffer.writeInteger(Int64(seconds)) - } - - // MARK: Private Constants - - private static let _microsecondsPerSecond: Int64 = 1_000_000 - private static let _secondsInDay: Int64 = 24 * 60 * 60 - - /// values are stored as seconds before or after midnight 2000-01-01 - private static let _psqlDateStart = Date(timeIntervalSince1970: 946_684_800) } +extension Date: PostgresCodable {} diff --git a/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift index 9159b311..e80da7be 100644 --- a/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift @@ -1,7 +1,7 @@ import NIOCore import struct Foundation.Decimal -extension Decimal: PostgresCodable { +extension Decimal: PostgresEncodable { var psqlType: PostgresDataType { .numeric } @@ -10,38 +10,42 @@ extension Decimal: PostgresCodable { .binary } - static func decode( + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + let numeric = PostgresNumeric(decimal: self) + byteBuffer.writeInteger(numeric.ndigits) + byteBuffer.writeInteger(numeric.weight) + byteBuffer.writeInteger(numeric.sign) + byteBuffer.writeInteger(numeric.dscale) + var value = numeric.value + byteBuffer.writeBuffer(&value) + } +} + +extension Decimal: PostgresDecodable { + init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext - ) throws -> Self { + ) throws { switch (format, type) { case (.binary, .numeric): guard let numeric = PostgresNumeric(buffer: &buffer) else { throw PostgresCastingError.Code.failure } - return numeric.decimal + self = numeric.decimal case (.text, .numeric): guard let string = buffer.readString(length: buffer.readableBytes), let value = Decimal(string: string) else { throw PostgresCastingError.Code.failure } - return value + self = value default: throw PostgresCastingError.Code.typeMismatch } } - - func encode( - into byteBuffer: inout ByteBuffer, - context: PostgresEncodingContext - ) { - let numeric = PostgresNumeric(decimal: self) - byteBuffer.writeInteger(numeric.ndigits) - byteBuffer.writeInteger(numeric.weight) - byteBuffer.writeInteger(numeric.sign) - byteBuffer.writeInteger(numeric.dscale) - var value = numeric.value - byteBuffer.writeBuffer(&value) - } } + +extension Decimal: PostgresCodable {} diff --git a/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift index 94b70820..1a39be18 100644 --- a/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift @@ -1,6 +1,6 @@ import NIOCore -extension Float: PostgresCodable { +extension Float: PostgresEncodable { var psqlType: PostgresDataType { .float4 } @@ -9,42 +9,46 @@ extension Float: PostgresCodable { .binary } - static func decode( + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.psqlWriteFloat(self) + } +} + +extension Float: PostgresDecodable { + init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext - ) throws -> Self { + ) throws { switch (format, type) { case (.binary, .float4): guard buffer.readableBytes == 4, let float = buffer.psqlReadFloat() else { throw PostgresCastingError.Code.failure } - return float + self = float case (.binary, .float8): guard buffer.readableBytes == 8, let double = buffer.psqlReadDouble() else { throw PostgresCastingError.Code.failure } - return Float(double) + self = Float(double) case (.text, .float4), (.text, .float8): guard let string = buffer.readString(length: buffer.readableBytes), let value = Float(string) else { throw PostgresCastingError.Code.failure } - return value + self = value default: throw PostgresCastingError.Code.typeMismatch } } - - func encode( - into byteBuffer: inout ByteBuffer, - context: PostgresEncodingContext - ) { - byteBuffer.psqlWriteFloat(self) - } } -extension Double: PostgresCodable { +extension Float: PostgresCodable {} + +extension Double: PostgresEncodable { var psqlType: PostgresDataType { .float8 } @@ -53,38 +57,41 @@ extension Double: PostgresCodable { .binary } - static func decode( + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.psqlWriteDouble(self) + } +} + +extension Double: PostgresDecodable { + init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext - ) throws -> Self { + ) throws { switch (format, type) { case (.binary, .float4): guard buffer.readableBytes == 4, let float = buffer.psqlReadFloat() else { throw PostgresCastingError.Code.failure } - return Double(float) + self = Double(float) case (.binary, .float8): guard buffer.readableBytes == 8, let double = buffer.psqlReadDouble() else { throw PostgresCastingError.Code.failure } - return double + self = double case (.text, .float4), (.text, .float8): guard let string = buffer.readString(length: buffer.readableBytes), let value = Double(string) else { throw PostgresCastingError.Code.failure } - return value + self = value default: throw PostgresCastingError.Code.typeMismatch } } - - func encode( - into byteBuffer: inout ByteBuffer, - context: PostgresEncodingContext - ) { - byteBuffer.psqlWriteDouble(self) - } } +extension Double: PostgresCodable {} diff --git a/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift index 6d980a40..e399a406 100644 --- a/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift @@ -1,6 +1,8 @@ import NIOCore -extension UInt8: PostgresCodable { +// MARK: UInt8 + +extension UInt8: PostgresEncodable { var psqlType: PostgresDataType { .char } @@ -9,33 +11,39 @@ extension UInt8: PostgresCodable { .binary } - static func decode( + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.writeInteger(self, as: UInt8.self) + } +} + +extension UInt8: PostgresDecodable { + init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext - ) throws -> Self { + ) throws { switch type { case .bpchar, .char: guard buffer.readableBytes == 1, let value = buffer.readInteger(as: UInt8.self) else { throw PostgresCastingError.Code.failure } - - return value + + self = value default: throw PostgresCastingError.Code.typeMismatch } } - - func encode( - into byteBuffer: inout ByteBuffer, - context: PostgresEncodingContext - ) { - byteBuffer.writeInteger(self, as: UInt8.self) - } } -extension Int16: PostgresCodable { +extension UInt8: PostgresCodable {} + +// MARK: Int16 + +extension Int16: PostgresEncodable { var psqlType: PostgresDataType { .int2 @@ -45,37 +53,43 @@ extension Int16: PostgresCodable { .binary } - static func decode( + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.writeInteger(self, as: Int16.self) + } +} + +extension Int16: PostgresDecodable { + init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext - ) throws -> Self { + ) throws { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { throw PostgresCastingError.Code.failure } - return value + self = value case (.text, .int2): guard let string = buffer.readString(length: buffer.readableBytes), let value = Int16(string) else { throw PostgresCastingError.Code.failure } - return value + self = value default: throw PostgresCastingError.Code.typeMismatch } } - - func encode( - into byteBuffer: inout ByteBuffer, - context: PostgresEncodingContext - ) { - byteBuffer.writeInteger(self, as: Int16.self) - } } -extension Int32: PostgresCodable { +extension Int16: PostgresCodable {} + +// MARK: Int32 + +extension Int32: PostgresEncodable { var psqlType: PostgresDataType { .int4 } @@ -84,42 +98,48 @@ extension Int32: PostgresCodable { .binary } - static func decode( + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.writeInteger(self, as: Int32.self) + } +} + +extension Int32: PostgresDecodable { + init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext - ) throws -> Self { + ) throws { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { throw PostgresCastingError.Code.failure } - return Int32(value) + self = Int32(value) case (.binary, .int4): guard buffer.readableBytes == 4, let value = buffer.readInteger(as: Int32.self) else { throw PostgresCastingError.Code.failure } - return Int32(value) + self = Int32(value) case (.text, .int2), (.text, .int4): guard let string = buffer.readString(length: buffer.readableBytes), let value = Int32(string) else { throw PostgresCastingError.Code.failure } - return value + self = value default: throw PostgresCastingError.Code.typeMismatch } } - - func encode( - into byteBuffer: inout ByteBuffer, - context: PostgresEncodingContext - ) { - byteBuffer.writeInteger(self, as: Int32.self) - } } -extension Int64: PostgresCodable { +extension Int32: PostgresCodable {} + +// MARK: Int64 + +extension Int64: PostgresEncodable { var psqlType: PostgresDataType { .int8 } @@ -128,47 +148,53 @@ extension Int64: PostgresCodable { .binary } - static func decode( + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.writeInteger(self, as: Int64.self) + } +} + +extension Int64: PostgresDecodable { + init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext - ) throws -> Self { + ) throws { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { throw PostgresCastingError.Code.failure } - return Int64(value) + self = Int64(value) case (.binary, .int4): guard buffer.readableBytes == 4, let value = buffer.readInteger(as: Int32.self) else { throw PostgresCastingError.Code.failure } - return Int64(value) + self = Int64(value) case (.binary, .int8): guard buffer.readableBytes == 8, let value = buffer.readInteger(as: Int64.self) else { throw PostgresCastingError.Code.failure } - return value + self = value case (.text, .int2), (.text, .int4), (.text, .int8): guard let string = buffer.readString(length: buffer.readableBytes), let value = Int64(string) else { throw PostgresCastingError.Code.failure } - return value + self = value default: throw PostgresCastingError.Code.typeMismatch } } - - func encode( - into byteBuffer: inout ByteBuffer, - context: PostgresEncodingContext - ) { - byteBuffer.writeInteger(self, as: Int64.self) - } } -extension Int: PostgresCodable { +extension Int64: PostgresCodable {} + +// MARK: Int + +extension Int: PostgresEncodable { var psqlType: PostgresDataType { switch self.bitWidth { case Int32.bitWidth: @@ -184,42 +210,46 @@ extension Int: PostgresCodable { .binary } - static func decode( + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + byteBuffer.writeInteger(self, as: Int.self) + } +} + +extension Int: PostgresDecodable { + init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext - ) throws -> Self { + ) throws { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { throw PostgresCastingError.Code.failure } - return Int(value) + self = Int(value) case (.binary, .int4): guard buffer.readableBytes == 4, let value = buffer.readInteger(as: Int32.self) else { throw PostgresCastingError.Code.failure } - return Int(value) + self = Int(value) case (.binary, .int8) where Int.bitWidth == 64: guard buffer.readableBytes == 8, let value = buffer.readInteger(as: Int.self) else { throw PostgresCastingError.Code.failure } - return value + self = value case (.text, .int2), (.text, .int4), (.text, .int8): guard let string = buffer.readString(length: buffer.readableBytes), let value = Int(string) else { throw PostgresCastingError.Code.failure } - return value + self = value default: throw PostgresCastingError.Code.typeMismatch } } - - func encode( - into byteBuffer: inout ByteBuffer, - context: PostgresEncodingContext - ) { - byteBuffer.writeInteger(self, as: Int.self) - } } + +extension Int: PostgresCodable {} diff --git a/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift index 9e5aeb18..a506c2d6 100644 --- a/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift @@ -5,7 +5,7 @@ import class Foundation.JSONDecoder private let JSONBVersionByte: UInt8 = 0x01 -extension PostgresCodable where Self: Codable { +extension PostgresEncodable where Self: Codable { var psqlType: PostgresDataType { .jsonb } @@ -14,30 +14,34 @@ extension PostgresCodable where Self: Codable { .binary } - static func decode( + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) throws { + byteBuffer.writeInteger(JSONBVersionByte) + try context.jsonEncoder.encode(self, into: &byteBuffer) + } +} + +extension PostgresDecodable where Self: Codable { + init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext - ) throws -> Self { + ) throws { switch (format, type) { case (.binary, .jsonb): guard JSONBVersionByte == buffer.readInteger(as: UInt8.self) else { throw PostgresCastingError.Code.failure } - return try context.jsonDecoder.decode(Self.self, from: buffer) + self = try context.jsonDecoder.decode(Self.self, from: buffer) case (.binary, .json), (.text, .jsonb), (.text, .json): - return try context.jsonDecoder.decode(Self.self, from: buffer) + self = try context.jsonDecoder.decode(Self.self, from: buffer) default: throw PostgresCastingError.Code.typeMismatch } } - - func encode( - into byteBuffer: inout ByteBuffer, - context: PostgresEncodingContext - ) throws { - byteBuffer.writeInteger(JSONBVersionByte) - try context.jsonEncoder.encode(self, into: &byteBuffer) - } } + +extension PostgresCodable where Self: Codable {} diff --git a/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift index d05b179e..c64da931 100644 --- a/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift @@ -1,6 +1,6 @@ import NIOCore -extension PostgresCodable where Self: RawRepresentable, RawValue: PostgresCodable { +extension PostgresEncodable where Self: RawRepresentable, RawValue: PostgresEncodable { var psqlType: PostgresDataType { self.rawValue.psqlType } @@ -9,24 +9,28 @@ extension PostgresCodable where Self: RawRepresentable, RawValue: PostgresCodabl self.rawValue.psqlFormat } - static func decode( + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) throws { + try rawValue.encode(into: &byteBuffer, context: context) + } +} + +extension PostgresDecodable where Self: RawRepresentable, RawValue: PostgresDecodable, RawValue._DecodableType == RawValue { + init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext - ) throws -> Self { - guard let rawValue = try? RawValue.decode(from: &buffer, type: type, format: format, context: context), + ) throws { + guard let rawValue = try? RawValue(from: &buffer, type: type, format: format, context: context), let selfValue = Self.init(rawValue: rawValue) else { throw PostgresCastingError.Code.failure } - - return selfValue - } - - func encode( - into byteBuffer: inout ByteBuffer, - context: PostgresEncodingContext - ) throws { - try rawValue.encode(into: &byteBuffer, context: context) + + self = selfValue } } + +extension PostgresCodable where Self: RawRepresentable, RawValue: PostgresCodable, RawValue._DecodableType == RawValue {} diff --git a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift index 538e2db5..56080540 100644 --- a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift @@ -1,7 +1,7 @@ import NIOCore import struct Foundation.UUID -extension String: PostgresCodable { +extension String: PostgresEncodable { var psqlType: PostgresDataType { .text } @@ -16,27 +16,31 @@ extension String: PostgresCodable { ) { byteBuffer.writeString(self) } - - static func decode( +} + +extension String: PostgresDecodable { + init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext - ) throws -> Self { + ) throws { switch (format, type) { case (_, .varchar), (_, .text), (_, .name): // we can force unwrap here, since this method only fails if there are not enough // bytes available. - return buffer.readString(length: buffer.readableBytes)! + self = buffer.readString(length: buffer.readableBytes)! case (_, .uuid): - guard let uuid = try? UUID.decode(from: &buffer, type: .uuid, format: format, context: context) else { + guard let uuid = try? UUID(from: &buffer, type: .uuid, format: format, context: context) else { throw PostgresCastingError.Code.failure } - return uuid.uuidString + self = uuid.uuidString default: throw PostgresCastingError.Code.typeMismatch } } } + +extension String: PostgresCodable {} diff --git a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift index 95e21dd3..2ec813bd 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift @@ -2,8 +2,7 @@ import NIOCore import struct Foundation.UUID import typealias Foundation.uuid_t -extension UUID: PostgresCodable { - +extension UUID: PostgresEncodable { var psqlType: PostgresDataType { .uuid } @@ -24,19 +23,21 @@ extension UUID: PostgresCodable { uuid.12, uuid.13, uuid.14, uuid.15, ]) } - - static func decode( +} + +extension UUID: PostgresDecodable { + init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext - ) throws -> Self { + ) throws { switch (format, type) { case (.binary, .uuid): guard let uuid = buffer.readUUID() else { throw PostgresCastingError.Code.failure } - return uuid + self = uuid case (.binary, .varchar), (.binary, .text), (.text, .uuid), @@ -45,17 +46,19 @@ extension UUID: PostgresCodable { guard buffer.readableBytes == 36 else { throw PostgresCastingError.Code.failure } - + guard let uuid = buffer.readString(length: 36).flatMap({ UUID(uuidString: $0) }) else { throw PostgresCastingError.Code.failure } - return uuid + self = uuid default: throw PostgresCastingError.Code.typeMismatch } } } +extension UUID: PostgresCodable {} + extension ByteBuffer { mutating func readUUID() -> UUID? { guard self.readableBytes >= MemoryLayout.size else { diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift index 2ae01e76..8d84b283 100644 --- a/Sources/PostgresNIO/New/PostgresCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -24,7 +24,7 @@ protocol PostgresDecodable { /// String? should be PostgresDecodable, String?? should not be PostgresDecodable associatedtype _DecodableType: PostgresDecodable = Self - /// Decode an entity from the `byteBuffer` in postgres wire format + /// Create an entity from the `byteBuffer` in postgres wire format /// /// - Parameters: /// - byteBuffer: A `ByteBuffer` to decode. The byteBuffer is sliced in such a way that it is expected @@ -35,12 +35,12 @@ protocol PostgresDecodable { /// - context: A `PSQLDecodingContext` providing context for decoding. This includes a `JSONDecoder` /// to use when decoding json and metadata to create better errors. /// - Returns: A decoded object - static func decode( + init( from byteBuffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext - ) throws -> Self + ) throws /// Decode an entity from the `byteBuffer` in postgres wire format. This method has a default implementation and /// is only overwritten for `Optional`s. Other than in the @@ -63,7 +63,7 @@ extension PostgresDecodable { guard var buffer = byteBuffer else { throw PostgresCastingError.Code.missingData } - return try self.decode(from: &buffer, type: type, format: format, context: context) + return try self.init(from: &buffer, type: type, format: format, context: context) } } @@ -116,7 +116,12 @@ extension PostgresDecodingContext where JSONDecoder == Foundation.JSONDecoder { extension Optional: PostgresDecodable where Wrapped: PostgresDecodable, Wrapped._DecodableType == Wrapped { typealias _DecodableType = Wrapped - static func decode(from byteBuffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, context: PostgresDecodingContext) throws -> Optional { + init( + from byteBuffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { preconditionFailure("This should not be called") } @@ -128,7 +133,7 @@ extension Optional: PostgresDecodable where Wrapped: PostgresDecodable, Wrapped. ) throws -> Optional { switch byteBuffer { case .some(var buffer): - return try Wrapped.decode(from: &buffer, type: type, format: format, context: context) + return try Wrapped(from: &buffer, type: type, format: format, context: context) case .none: return .none } diff --git a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift index 62a6629f..a7c40550 100644 --- a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift @@ -64,7 +64,7 @@ class Array_PSQLCodableTests: XCTestCase { XCTAssertNoThrow(try values.encode(into: &buffer, context: .default)) var result: [String]? - XCTAssertNoThrow(result = try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .default)) + XCTAssertNoThrow(result = try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) XCTAssertEqual(values, result) } @@ -75,7 +75,7 @@ class Array_PSQLCodableTests: XCTestCase { XCTAssertNoThrow(try values.encode(into: &buffer, context: .default)) var result: [String]? - XCTAssertNoThrow(result = try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .default)) + XCTAssertNoThrow(result = try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) XCTAssertEqual(values, result) } @@ -85,7 +85,7 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int32(0)) buffer.writeInteger(String.psqlArrayElementType.rawValue) - XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .default)) { + XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -96,7 +96,7 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int32(1)) // invalid value, must always be 0 buffer.writeInteger(String.psqlArrayElementType.rawValue) - XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .default)) { + XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -106,7 +106,7 @@ class Array_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) - XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .default)) { + XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -119,7 +119,7 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int32(-123)) // expected element count buffer.writeInteger(Int32(1)) // dimensions... must be one - XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .default)) { + XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -132,7 +132,7 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int32(1)) // expected element count buffer.writeInteger(Int32(2)) // dimensions... must be one - XCTAssertThrowsError(try [String].decode(from: &buffer, type: .textArray, format: .binary, context: .default)) { + XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -146,7 +146,7 @@ class Array_PSQLCodableTests: XCTestCase { unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // dimensions unexpectedEndInElementLengthBuffer.writeInteger(Int16(1)) // length of element, must be Int32 - XCTAssertThrowsError(try [String].decode(from: &unexpectedEndInElementLengthBuffer, type: .textArray, format: .binary, context: .default)) { + XCTAssertThrowsError(try [String](from: &unexpectedEndInElementLengthBuffer, type: .textArray, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } @@ -159,7 +159,7 @@ class Array_PSQLCodableTests: XCTestCase { unexpectedEndInElementBuffer.writeInteger(Int32(12)) // length of element, must be Int32 unexpectedEndInElementBuffer.writeString("Hello World") // only 11 bytes, 12 needed! - XCTAssertThrowsError(try [String].decode(from: &unexpectedEndInElementBuffer, type: .textArray, format: .binary, context: .default)) { + XCTAssertThrowsError(try [String](from: &unexpectedEndInElementBuffer, type: .textArray, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } diff --git a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift index f9c8103b..8f77bcea 100644 --- a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift @@ -17,7 +17,7 @@ class Bool_PSQLCodableTests: XCTestCase { XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 1) var result: Bool? - XCTAssertNoThrow(result = try Bool.decode(from: &buffer, type: .bool, format: .binary, context: .default)) + XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) XCTAssertEqual(value, result) } @@ -32,7 +32,7 @@ class Bool_PSQLCodableTests: XCTestCase { XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 0) var result: Bool? - XCTAssertNoThrow(result = try Bool.decode(from: &buffer, type: .bool, format: .binary, context: .default)) + XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) XCTAssertEqual(value, result) } @@ -40,7 +40,7 @@ class Bool_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int64(1)) - XCTAssertThrowsError(try Bool.decode(from: &buffer, type: .bool, format: .binary, context: .default)) { + XCTAssertThrowsError(try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -49,7 +49,7 @@ class Bool_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(UInt8(13)) - XCTAssertThrowsError(try Bool.decode(from: &buffer, type: .bool, format: .binary, context: .default)) { + XCTAssertThrowsError(try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -63,7 +63,7 @@ class Bool_PSQLCodableTests: XCTestCase { buffer.writeInteger(UInt8(ascii: "t")) var result: Bool? - XCTAssertNoThrow(result = try Bool.decode(from: &buffer, type: .bool, format: .text, context: .default)) + XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .text, context: .default)) XCTAssertEqual(value, result) } @@ -74,7 +74,7 @@ class Bool_PSQLCodableTests: XCTestCase { buffer.writeInteger(UInt8(ascii: "f")) var result: Bool? - XCTAssertNoThrow(result = try Bool.decode(from: &buffer, type: .bool, format: .text, context: .default)) + XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .text, context: .default)) XCTAssertEqual(value, result) } @@ -82,7 +82,7 @@ class Bool_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(UInt8(13)) - XCTAssertThrowsError(try Bool.decode(from: &buffer, type: .bool, format: .text, context: .default)) { + XCTAssertThrowsError(try Bool(from: &buffer, type: .bool, format: .text, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } diff --git a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift index 1dee1e06..b67c0b5e 100644 --- a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift @@ -12,7 +12,7 @@ class Bytes_PSQLCodableTests: XCTestCase { XCTAssertEqual(data.psqlType, .bytea) var result: Data? - XCTAssertNoThrow(result = try Data.decode(from: &buffer, type: .bytea, format: .binary, context: .default)) + result = Data(from: &buffer, type: .bytea, format: .binary, context: .default) XCTAssertEqual(data, result) } @@ -24,7 +24,7 @@ class Bytes_PSQLCodableTests: XCTestCase { XCTAssertEqual(bytes.psqlType, .bytea) var result: ByteBuffer? - XCTAssertNoThrow(result = try ByteBuffer.decode(from: &buffer, type: .bytea, format: .binary, context: .default)) + result = ByteBuffer(from: &buffer, type: .bytea, format: .binary, context: .default) XCTAssertEqual(bytes, result) } diff --git a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift index 02bc4e97..38ce1d04 100644 --- a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift @@ -13,7 +13,7 @@ class Date_PSQLCodableTests: XCTestCase { XCTAssertEqual(buffer.readableBytes, 8) var result: Date? - XCTAssertNoThrow(result = try Date.decode(from: &buffer, type: .timestamptz, format: .binary, context: .default)) + XCTAssertNoThrow(result = try Date(from: &buffer, type: .timestamptz, format: .binary, context: .default)) XCTAssertEqual(value, result) } @@ -22,7 +22,7 @@ class Date_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) var result: Date? - XCTAssertNoThrow(result = try Date.decode(from: &buffer, type: .timestamptz, format: .binary, context: .default)) + XCTAssertNoThrow(result = try Date(from: &buffer, type: .timestamptz, format: .binary, context: .default)) XCTAssertNotNil(result) } @@ -31,7 +31,7 @@ class Date_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) - XCTAssertThrowsError(try Date.decode(from: &buffer, type: .timestamptz, format: .binary, context: .default)) { + XCTAssertThrowsError(try Date(from: &buffer, type: .timestamptz, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -41,14 +41,14 @@ class Date_PSQLCodableTests: XCTestCase { firstDateBuffer.writeInteger(Int32.min) var firstDate: Date? - XCTAssertNoThrow(firstDate = try Date.decode(from: &firstDateBuffer, type: .date, format: .binary, context: .default)) + XCTAssertNoThrow(firstDate = try Date(from: &firstDateBuffer, type: .date, format: .binary, context: .default)) XCTAssertNotNil(firstDate) var lastDateBuffer = ByteBuffer() lastDateBuffer.writeInteger(Int32.max) var lastDate: Date? - XCTAssertNoThrow(lastDate = try Date.decode(from: &lastDateBuffer, type: .date, format: .binary, context: .default)) + XCTAssertNoThrow(lastDate = try Date(from: &lastDateBuffer, type: .date, format: .binary, context: .default)) XCTAssertNotNil(lastDate) } @@ -57,14 +57,14 @@ class Date_PSQLCodableTests: XCTestCase { firstDateBuffer.writeInteger(Int32.min) var firstDate: Date? - XCTAssertNoThrow(firstDate = try Date.decode(from: &firstDateBuffer, type: .date, format: .binary, context: .default)) + XCTAssertNoThrow(firstDate = try Date(from: &firstDateBuffer, type: .date, format: .binary, context: .default)) XCTAssertNotNil(firstDate) var lastDateBuffer = ByteBuffer() lastDateBuffer.writeInteger(Int32.max) var lastDate: Date? - XCTAssertNoThrow(lastDate = try Date.decode(from: &lastDateBuffer, type: .date, format: .binary, context: .default)) + XCTAssertNoThrow(lastDate = try Date(from: &lastDateBuffer, type: .date, format: .binary, context: .default)) XCTAssertNotNil(lastDate) } @@ -72,7 +72,7 @@ class Date_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int64(0)) - XCTAssertThrowsError(try Date.decode(from: &buffer, type: .date, format: .binary, context: .default)) { + XCTAssertThrowsError(try Date(from: &buffer, type: .date, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -81,7 +81,7 @@ class Date_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int64(0)) - XCTAssertThrowsError(try Date.decode(from: &buffer, type: .int8, format: .binary, context: .default)) { + XCTAssertThrowsError(try Date(from: &buffer, type: .int8, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } } diff --git a/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift index 5e385de9..2898f998 100644 --- a/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift @@ -13,7 +13,7 @@ class Decimal_PSQLCodableTests: XCTestCase { XCTAssertEqual(value.psqlType, .numeric) var result: Decimal? - XCTAssertNoThrow(result = try Decimal.decode(from: &buffer, type: .numeric, format: .binary, context: .default)) + XCTAssertNoThrow(result = try Decimal(from: &buffer, type: .numeric, format: .binary, context: .default)) XCTAssertEqual(value, result) } } @@ -22,7 +22,7 @@ class Decimal_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int64(0)) - XCTAssertThrowsError(try Decimal.decode(from: &buffer, type: .int8, format: .binary, context: .default)) { + XCTAssertThrowsError(try Decimal(from: &buffer, type: .int8, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } } diff --git a/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift index 5bd6eacb..3cac7e6f 100644 --- a/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift @@ -14,7 +14,7 @@ class Float_PSQLCodableTests: XCTestCase { XCTAssertEqual(buffer.readableBytes, 8) var result: Double? - XCTAssertNoThrow(result = try Double.decode(from: &buffer, type: .float8, format: .binary, context: .default)) + XCTAssertNoThrow(result = try Double(from: &buffer, type: .float8, format: .binary, context: .default)) XCTAssertEqual(value, result) } } @@ -29,7 +29,7 @@ class Float_PSQLCodableTests: XCTestCase { XCTAssertEqual(buffer.readableBytes, 4) var result: Float? - XCTAssertNoThrow(result = try Float.decode(from: &buffer, type: .float4, format: .binary, context: .default)) + XCTAssertNoThrow(result = try Float(from: &buffer, type: .float4, format: .binary, context: .default)) XCTAssertEqual(value, result) } } @@ -43,7 +43,7 @@ class Float_PSQLCodableTests: XCTestCase { XCTAssertEqual(buffer.readableBytes, 8) var result: Double? - XCTAssertNoThrow(result = try Double.decode(from: &buffer, type: .float8, format: .binary, context: .default)) + XCTAssertNoThrow(result = try Double(from: &buffer, type: .float8, format: .binary, context: .default)) XCTAssertEqual(result?.isNaN, true) } @@ -56,7 +56,7 @@ class Float_PSQLCodableTests: XCTestCase { XCTAssertEqual(buffer.readableBytes, 8) var result: Double? - XCTAssertNoThrow(result = try Double.decode(from: &buffer, type: .float8, format: .binary, context: .default)) + XCTAssertNoThrow(result = try Double(from: &buffer, type: .float8, format: .binary, context: .default)) XCTAssertEqual(result?.isInfinite, true) } @@ -70,7 +70,7 @@ class Float_PSQLCodableTests: XCTestCase { XCTAssertEqual(buffer.readableBytes, 4) var result: Double? - XCTAssertNoThrow(result = try Double.decode(from: &buffer, type: .float4, format: .binary, context: .default)) + XCTAssertNoThrow(result = try Double(from: &buffer, type: .float4, format: .binary, context: .default)) XCTAssertEqual(result, Double(value)) } } @@ -85,7 +85,7 @@ class Float_PSQLCodableTests: XCTestCase { XCTAssertEqual(buffer.readableBytes, 8) var result: Float? - XCTAssertNoThrow(result = try Float.decode(from: &buffer, type: .float8, format: .binary, context: .default)) + XCTAssertNoThrow(result = try Float(from: &buffer, type: .float8, format: .binary, context: .default)) XCTAssertEqual(result, Float(value)) } } @@ -97,22 +97,22 @@ class Float_PSQLCodableTests: XCTestCase { fourByteBuffer.writeInteger(Int32(0)) var toLongBuffer1 = eightByteBuffer - XCTAssertThrowsError(try Double.decode(from: &toLongBuffer1, type: .float4, format: .binary, context: .default)) { + XCTAssertThrowsError(try Double(from: &toLongBuffer1, type: .float4, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } var toLongBuffer2 = eightByteBuffer - XCTAssertThrowsError(try Float.decode(from: &toLongBuffer2, type: .float4, format: .binary, context: .default)) { + XCTAssertThrowsError(try Float(from: &toLongBuffer2, type: .float4, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } var toShortBuffer1 = fourByteBuffer - XCTAssertThrowsError(try Double.decode(from: &toShortBuffer1, type: .float8, format: .binary, context: .default)) { + XCTAssertThrowsError(try Double(from: &toShortBuffer1, type: .float8, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } var toShortBuffer2 = fourByteBuffer - XCTAssertThrowsError(try Float.decode(from: &toShortBuffer2, type: .float8, format: .binary, context: .default)) { + XCTAssertThrowsError(try Float(from: &toShortBuffer2, type: .float8, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -122,12 +122,12 @@ class Float_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int64(0)) var copy1 = buffer - XCTAssertThrowsError(try Double.decode(from: ©1, type: .int8, format: .binary, context: .default)) { + XCTAssertThrowsError(try Double(from: ©1, type: .int8, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } var copy2 = buffer - XCTAssertThrowsError(try Float.decode(from: ©2, type: .int8, format: .binary, context: .default)) { + XCTAssertThrowsError(try Float(from: ©2, type: .int8, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } } diff --git a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift index 04085168..46563973 100644 --- a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift @@ -22,7 +22,7 @@ class JSON_PSQLCodableTests: XCTestCase { XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 1) var result: Hello? - XCTAssertNoThrow(result = try Hello.decode(from: &buffer, type: .jsonb, format: .binary, context: .default)) + XCTAssertNoThrow(result = try Hello(from: &buffer, type: .jsonb, format: .binary, context: .default)) XCTAssertEqual(result, hello) } @@ -31,7 +31,7 @@ class JSON_PSQLCodableTests: XCTestCase { buffer.writeString(#"{"hello":"world"}"#) var result: Hello? - XCTAssertNoThrow(result = try Hello.decode(from: &buffer, type: .json, format: .binary, context: .default)) + XCTAssertNoThrow(result = try Hello(from: &buffer, type: .json, format: .binary, context: .default)) XCTAssertEqual(result, Hello(name: "world")) } @@ -45,7 +45,7 @@ class JSON_PSQLCodableTests: XCTestCase { for (format, dataType) in combinations { var loopBuffer = buffer var result: Hello? - XCTAssertNoThrow(result = try Hello.decode(from: &loopBuffer, type: dataType, format: format, context: .default)) + XCTAssertNoThrow(result = try Hello(from: &loopBuffer, type: dataType, format: format, context: .default)) XCTAssertEqual(result, Hello(name: "world")) } } @@ -54,7 +54,7 @@ class JSON_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeString(#"{"hello":"world"}"#) - XCTAssertThrowsError(try Hello.decode(from: &buffer, type: .jsonb, format: .binary, context: .default)) { + XCTAssertThrowsError(try Hello(from: &buffer, type: .jsonb, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -63,7 +63,7 @@ class JSON_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeString(#"{"hello":"world"}"#) - XCTAssertThrowsError(try Hello.decode(from: &buffer, type: .text, format: .binary, context: .default)) { + XCTAssertThrowsError(try Hello(from: &buffer, type: .text, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } } diff --git a/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift index d017d00e..99a250aa 100644 --- a/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift @@ -20,7 +20,7 @@ class RawRepresentable_PSQLCodableTests: XCTestCase { XCTAssertEqual(buffer.readableBytes, 2) var result: MyRawRepresentable? - XCTAssertNoThrow(result = try MyRawRepresentable.decode(from: &buffer, type: Int16.psqlArrayElementType, format: .binary, context: .default)) + XCTAssertNoThrow(result = try MyRawRepresentable(from: &buffer, type: Int16.psqlArrayElementType, format: .binary, context: .default)) XCTAssertEqual(value, result) } } @@ -29,7 +29,7 @@ class RawRepresentable_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int16(4)) // out of bounds - XCTAssertThrowsError(try MyRawRepresentable.decode(from: &buffer, type: Int16.psqlArrayElementType, format: .binary, context: .default)) { + XCTAssertThrowsError(try MyRawRepresentable(from: &buffer, type: Int16.psqlArrayElementType, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -38,7 +38,7 @@ class RawRepresentable_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int32(1)) // out of bounds - XCTAssertThrowsError(try MyRawRepresentable.decode(from: &buffer, type: Int32.psqlArrayElementType, format: .binary, context: .default)) { + XCTAssertThrowsError(try MyRawRepresentable(from: &buffer, type: Int32.psqlArrayElementType, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } diff --git a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift index e4c62704..9d2937e4 100644 --- a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift @@ -26,7 +26,7 @@ class String_PSQLCodableTests: XCTestCase { for dataType in dataTypes { var loopBuffer = buffer var result: String? - XCTAssertNoThrow(result = try String.decode(from: &loopBuffer, type: dataType, format: .binary, context: .default)) + XCTAssertNoThrow(result = try String(from: &loopBuffer, type: dataType, format: .binary, context: .default)) XCTAssertEqual(result, expected) } } @@ -37,7 +37,7 @@ class String_PSQLCodableTests: XCTestCase { for dataType in dataTypes { var loopBuffer = buffer - XCTAssertThrowsError(try String.decode(from: &loopBuffer, type: dataType, format: .binary, context: .default)) { + XCTAssertThrowsError(try String(from: &loopBuffer, type: dataType, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } } @@ -49,7 +49,7 @@ class String_PSQLCodableTests: XCTestCase { uuid.encode(into: &buffer, context: .default) var decoded: String? - XCTAssertNoThrow(decoded = try String.decode(from: &buffer, type: .uuid, format: .binary, context: .default)) + XCTAssertNoThrow(decoded = try String(from: &buffer, type: .uuid, format: .binary, context: .default)) XCTAssertEqual(decoded, uuid.uuidString) } @@ -60,7 +60,7 @@ class String_PSQLCodableTests: XCTestCase { // this makes only 15 bytes readable. this should lead to an error buffer.moveReaderIndex(forwardBy: 1) - XCTAssertThrowsError(try String.decode(from: &buffer, type: .uuid, format: .binary, context: .default)) { + XCTAssertThrowsError(try String(from: &buffer, type: .uuid, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } diff --git a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift index 840b8531..1df8001b 100644 --- a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift @@ -34,7 +34,7 @@ class UUID_PSQLCodableTests: XCTestCase { XCTAssertEqual(byteIterator.next(), uuid.uuid.15) var decoded: UUID? - XCTAssertNoThrow(decoded = try UUID.decode(from: &buffer, type: .uuid, format: .binary, context: .default)) + XCTAssertNoThrow(decoded = try UUID(from: &buffer, type: .uuid, format: .binary, context: .default)) XCTAssertEqual(decoded, uuid) } } @@ -57,7 +57,7 @@ class UUID_PSQLCodableTests: XCTestCase { for (format, dataType) in options { var loopBuffer = lowercaseBuffer var decoded: UUID? - XCTAssertNoThrow(decoded = try UUID.decode(from: &loopBuffer, type: dataType, format: format, context: .default)) + XCTAssertNoThrow(decoded = try UUID(from: &loopBuffer, type: dataType, format: format, context: .default)) XCTAssertEqual(decoded, uuid) } @@ -68,7 +68,7 @@ class UUID_PSQLCodableTests: XCTestCase { for (format, dataType) in options { var loopBuffer = uppercaseBuffer var decoded: UUID? - XCTAssertNoThrow(decoded = try UUID.decode(from: &loopBuffer, type: dataType, format: format, context: .default)) + XCTAssertNoThrow(decoded = try UUID(from: &loopBuffer, type: dataType, format: format, context: .default)) XCTAssertEqual(decoded, uuid) } } @@ -82,7 +82,7 @@ class UUID_PSQLCodableTests: XCTestCase { // this makes only 15 bytes readable. this should lead to an error buffer.moveReaderIndex(forwardBy: 1) - XCTAssertThrowsError(try UUID.decode(from: &buffer, type: .uuid, format: .binary, context: .default)) { error in + XCTAssertThrowsError(try UUID(from: &buffer, type: .uuid, format: .binary, context: .default)) { error in XCTAssertEqual(error as? PostgresCastingError.Code, .failure) } } @@ -98,7 +98,7 @@ class UUID_PSQLCodableTests: XCTestCase { for dataType in dataTypes { var loopBuffer = buffer - XCTAssertThrowsError(try UUID.decode(from: &loopBuffer, type: dataType, format: .binary, context: .default)) { + XCTAssertThrowsError(try UUID(from: &loopBuffer, type: dataType, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -113,7 +113,7 @@ class UUID_PSQLCodableTests: XCTestCase { for dataType in dataTypes { var copy = buffer - XCTAssertThrowsError(try UUID.decode(from: ©, type: dataType, format: .binary, context: .default)) { + XCTAssertThrowsError(try UUID(from: ©, type: dataType, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) } } From 495dec9a1621ac6556963e4e093d5384ccee9aa2 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 5 Mar 2022 21:30:39 +0100 Subject: [PATCH 077/292] Make PostgresEncodingContext & PostgresDecodingContext public (#243) --- Sources/PostgresNIO/New/PostgresCodable.swift | 37 ++++++++++++++----- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift index 8d84b283..b197fdd6 100644 --- a/Sources/PostgresNIO/New/PostgresCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -89,28 +89,47 @@ extension PostgresEncodable { } } -struct PostgresEncodingContext { - let jsonEncoder: JSONEncoder +/// A context that is passed to Swift objects that are encoded into the Postgres wire format. Used +/// to pass further information to the encoding method. +public struct PostgresEncodingContext { + /// A ``PostgresJSONEncoder`` used to encode the object to json. + public var jsonEncoder: JSONEncoder - init(jsonEncoder: JSONEncoder) { + + /// Creates a ``PostgresEncodingContext`` with the given ``PostgresJSONEncoder``. In case you want + /// to use the a ``PostgresEncodingContext`` with an unconfigured Foundation `JSONEncoder` + /// you can use the ``default`` context instead. + /// + /// - Parameter jsonEncoder: A ``PostgresJSONEncoder`` to use when encoding objects to json + public init(jsonEncoder: JSONEncoder) { self.jsonEncoder = jsonEncoder } } extension PostgresEncodingContext where JSONEncoder == Foundation.JSONEncoder { - static let `default` = PostgresEncodingContext(jsonEncoder: JSONEncoder()) + /// A default ``PostgresEncodingContext`` that uses a Foundation `JSONEncoder`. + public static let `default` = PostgresEncodingContext(jsonEncoder: JSONEncoder()) } -struct PostgresDecodingContext { - let jsonDecoder: JSONDecoder - - init(jsonDecoder: JSONDecoder) { +/// A context that is passed to Swift objects that are decoded from the Postgres wire format. Used +/// to pass further information to the decoding method. +public struct PostgresDecodingContext { + /// A ``PostgresJSONDecoder`` used to decode the object from json. + public var jsonDecoder: JSONDecoder + + /// Creates a ``PostgresDecodingContext`` with the given ``PostgresJSONDecoder``. In case you want + /// to use the a ``PostgresDecodingContext`` with an unconfigured Foundation `JSONDecoder` + /// you can use the ``default`` context instead. + /// + /// - Parameter jsonDecoder: A ``PostgresJSONDecoder`` to use when decoding objects from json + public init(jsonDecoder: JSONDecoder) { self.jsonDecoder = jsonDecoder } } extension PostgresDecodingContext where JSONDecoder == Foundation.JSONDecoder { - static let `default` = PostgresDecodingContext(jsonDecoder: Foundation.JSONDecoder()) + /// A default ``PostgresDecodingContext`` that uses a Foundation `JSONDecoder`. + public static let `default` = PostgresDecodingContext(jsonDecoder: Foundation.JSONDecoder()) } extension Optional: PostgresDecodable where Wrapped: PostgresDecodable, Wrapped._DecodableType == Wrapped { From c54c00656f5d820a33abee728d1f86b1fbc30b6b Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sun, 6 Mar 2022 10:02:53 +0100 Subject: [PATCH 078/292] Add Sendable checking (#242) --- .../PostgresNIO/Connection/PostgresConnection.swift | 4 ++++ Sources/PostgresNIO/Data/PostgresData.swift | 8 ++++++++ Sources/PostgresNIO/Data/PostgresDataType.swift | 9 +++++++++ Sources/PostgresNIO/Data/PostgresRow.swift | 12 ++++++++++++ .../PostgresNIO/New/Messages/RowDescription.swift | 4 ++++ Sources/PostgresNIO/New/PSQLRowStream.swift | 5 +++++ Sources/PostgresNIO/New/PostgresCell.swift | 8 ++++++++ 7 files changed, 50 insertions(+) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 58eb621e..20a41af7 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -566,3 +566,7 @@ private extension String { } } } + +#if swift(>=5.6) +extension PostgresConnection: @unchecked Sendable {} +#endif diff --git a/Sources/PostgresNIO/Data/PostgresData.swift b/Sources/PostgresNIO/Data/PostgresData.swift index 96ac7023..16d4b3ee 100644 --- a/Sources/PostgresNIO/Data/PostgresData.swift +++ b/Sources/PostgresNIO/Data/PostgresData.swift @@ -1,4 +1,8 @@ +#if swift(>=5.6) +@preconcurrency import NIOCore +#else import NIOCore +#endif import Foundation public struct PostgresData: CustomStringConvertible, CustomDebugStringConvertible { @@ -112,3 +116,7 @@ extension PostgresData: PostgresDataConvertible { return self } } + +#if swift(>=5.6) +extension PostgresData: Sendable {} +#endif diff --git a/Sources/PostgresNIO/Data/PostgresDataType.swift b/Sources/PostgresNIO/Data/PostgresDataType.swift index 3daa85c5..55f529dc 100644 --- a/Sources/PostgresNIO/Data/PostgresDataType.swift +++ b/Sources/PostgresNIO/Data/PostgresDataType.swift @@ -17,6 +17,11 @@ extension PostgresFormat: CustomStringConvertible { } } +#if swift(>=5.6) +extension PostgresFormat: Sendable {} +#endif + + // TODO: The Codable conformance does not make any sense. Let's remove this with next major break. extension PostgresFormat: Codable {} @@ -233,6 +238,10 @@ public struct PostgresDataType: RawRepresentable, Hashable, CustomStringConverti } } +#if swift(>=5.6) +extension PostgresDataType: Sendable {} +#endif + // TODO: The Codable conformance does not make any sense. Let's remove this with next major break. extension PostgresDataType: Codable {} diff --git a/Sources/PostgresNIO/Data/PostgresRow.swift b/Sources/PostgresNIO/Data/PostgresRow.swift index 3ac20c5e..83343812 100644 --- a/Sources/PostgresNIO/Data/PostgresRow.swift +++ b/Sources/PostgresNIO/Data/PostgresRow.swift @@ -1,4 +1,8 @@ +#if swift(>=5.6) +@preconcurrency import NIOCore +#else import NIOCore +#endif import class Foundation.JSONDecoder /// `PostgresRow` represents a single table row that is received from the server for a query or a prepared statement. @@ -311,3 +315,11 @@ extension PostgresRow: CustomStringConvertible { return row.description } } + +#if swift(>=5.6) +extension PostgresRow: Sendable {} + +extension PostgresRandomAccessRow: Sendable {} +#endif + + diff --git a/Sources/PostgresNIO/New/Messages/RowDescription.swift b/Sources/PostgresNIO/New/Messages/RowDescription.swift index de855e98..ba3eee7f 100644 --- a/Sources/PostgresNIO/New/Messages/RowDescription.swift +++ b/Sources/PostgresNIO/New/Messages/RowDescription.swift @@ -76,3 +76,7 @@ struct RowDescription: PostgresBackendMessage.PayloadDecodable, Equatable { return RowDescription(columns: result) } } + +#if swift(>=5.6) +extension RowDescription.Column: Sendable {} +#endif diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index 2d0ec455..58730851 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -396,3 +396,8 @@ protocol PSQLRowsDataSource { func cancel(for stream: PSQLRowStream) } + +#if swift(>=5.6) +// Thread safety is guaranteed in the RowStream through dispatching onto the NIO EventLoop. +extension PSQLRowStream: @unchecked Sendable {} +#endif diff --git a/Sources/PostgresNIO/New/PostgresCell.swift b/Sources/PostgresNIO/New/PostgresCell.swift index a29eacd6..624e845d 100644 --- a/Sources/PostgresNIO/New/PostgresCell.swift +++ b/Sources/PostgresNIO/New/PostgresCell.swift @@ -1,4 +1,8 @@ +#if swift(>=5.6) +@preconcurrency import NIOCore +#else import NIOCore +#endif public struct PostgresCell: Equatable { public var bytes: ByteBuffer? @@ -49,3 +53,7 @@ extension PostgresCell { } } } + +#if swift(>=5.6) +extension PostgresCell: Sendable {} +#endif From 32120378171637358c3ddbde68df46203e1fae70 Mon Sep 17 00:00:00 2001 From: BennyDB <74614235+BennyDeBock@users.noreply.github.com> Date: Mon, 7 Mar 2022 15:15:58 +0100 Subject: [PATCH 079/292] Add project board workflow (#246) --- .github/workflows/projectboard.yml | 72 ++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 .github/workflows/projectboard.yml diff --git a/.github/workflows/projectboard.yml b/.github/workflows/projectboard.yml new file mode 100644 index 00000000..e4ff9c69 --- /dev/null +++ b/.github/workflows/projectboard.yml @@ -0,0 +1,72 @@ +name: first-issues-to-beginner-issues-project +on: + # Trigger when an issue gets labeled or deleted + issues: + types: [reopened, closed, labeled, unlabeled, assigned, unassigned] + +jobs: + manage_project_issues: + strategy: + fail-fast: false + matrix: + project: + - 'Beginner Issues' + runs-on: ubuntu-latest + if: contains(github.event.issue.labels.*.name, 'good first issue') + steps: + # When an issue that is open is labeled, unassigned or reopened without a assigned member + # create or move the card to "To do" + - name: Create or Update Project Card + if: | + github.event.action == 'labeled' || + github.event.action == 'reopened' || + github.event.action == 'unassigned' + uses: alex-page/github-project-automation-plus@v0.8.1 + with: + project: ${{ matrix.project }} + column: 'To do' + repo-token: ${{ secrets.PROJECT_BOARD_AUTOMATION_PAT }} + + # When an issue that is open is assigned and has an assigned member + # create or move the card to "In progress" + - name: Assign Project Card + if: | + github.event.action == 'assigned' + uses: alex-page/github-project-automation-plus@v0.8.1 + with: + project: ${{ matrix.project }} + column: 'In progress' + repo-token: ${{ secrets.PROJECT_BOARD_AUTOMATION_PAT }} + + # When an issue is closed with the good first issue tag + # Create or move the card to "Done" + - name: Close Project Card + if: | + github.event.action == 'closed' + uses: asmfnk/my-github-project-automation@v0.5.0 + with: + project: ${{ matrix.project }} + column: 'Done' + repo-token: ${{ secrets.PROJECT_BOARD_AUTOMATION_PAT }} + + remove_project_issues: + strategy: + fail-fast: false + matrix: + project: + - 'Beginner Issues' + runs-on: ubuntu-latest + if: ${{ !contains(github.event.issue.labels.*.name, 'good first issue') }} + steps: + # When an issue has the tag 'good first issue' removed + # Remove the card from the board + - name: Remove Project Card + if: | + github.event.action == 'unlabeled' + uses: alex-page/github-project-automation-plus@v0.8.1 + with: + project: ${{ matrix.project }} + column: 'To do' + repo-token: ${{ secrets.PROJECT_BOARD_AUTOMATION_PAT }} + action: delete + From 2998b1cf87d0285f6762bf43844e442e4ccc7b69 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 8 Mar 2022 20:20:27 +0100 Subject: [PATCH 080/292] New PostgresConnection connect API (#245) --- .../Connection/PostgresConnection.swift | 206 +++++++++++++----- .../PostgresNIO/New/PSQLChannelHandler.swift | 12 +- .../PSQLIntegrationTests.swift | 19 +- Tests/IntegrationTests/PostgresNIOTests.swift | 30 ++- Tests/IntegrationTests/Utilities.swift | 25 ++- .../New/PSQLChannelHandlerTests.swift | 13 +- .../New/PSQLConnectionTests.swift | 9 +- 7 files changed, 218 insertions(+), 96 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 20a41af7..d9f24117 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -4,23 +4,38 @@ import NIOSSL import Logging import NIOPosix +/// A Postgres connection. Use it to run queries against a Postgres server. public final class PostgresConnection { - typealias ID = Int - - struct Configuration { - struct Authentication { - var username: String - var database: String? = nil - var password: String? = nil - - init(username: String, password: String?, database: String?) { + /// A Postgres connection ID + public typealias ID = Int + + /// A configuration object for a connection + public struct Configuration { + /// A structure to configure the connection's authentication properties + public struct Authentication { + /// The username to connect with. + /// + /// - Default: postgres + public var username: String + + /// The database to open on the server + /// + /// - Default: `nil` + public var database: Optional + + /// The database user's password. + /// + /// - Default: `nil` + public var password: Optional + + public init(username: String, database: String?, password: String?) { self.username = username self.database = database self.password = password } } - struct TLS { + public struct TLS { enum Base { case disable case prefer(NIOSSLContext) @@ -33,44 +48,50 @@ public final class PostgresConnection { self.base = base } - static var disable: Self = Self.init(.disable) + /// Do not try to create a TLS connection to the server. + public static var disable: Self = Self.init(.disable) - static func prefer(_ sslContext: NIOSSLContext) -> Self { + /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. + /// If the server does not support TLS, create an insecure connection. + public static func prefer(_ sslContext: NIOSSLContext) -> Self { self.init(.prefer(sslContext)) } - static func require(_ sslContext: NIOSSLContext) -> Self { + /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. + /// If the server does not support TLS, fail the connection creation. + public static func require(_ sslContext: NIOSSLContext) -> Self { self.init(.require(sslContext)) } } - enum Connection { - case unresolved(host: String, port: Int) - case resolved(address: SocketAddress, serverName: String?) + public struct Connection { + /// The server to connect to + /// + /// - Default: localhost + public var host: String + + /// The server port to connect to. + /// + /// - Default: 5432 + public var port: Int + + public init(host: String, port: Int = 5432) { + self.host = host + self.port = port + } } - var connection: Connection + public var connection: Connection /// The authentication properties to send to the Postgres server during startup auth handshake - var authentication: Authentication? + public var authentication: Authentication - var tls: TLS + public var tls: TLS - init(host: String, - port: Int = 5432, - username: String, - database: String? = nil, - password: String? = nil, - tls: TLS = .disable - ) { - self.connection = .unresolved(host: host, port: port) - self.authentication = Authentication(username: username, password: password, database: database) - self.tls = tls - } - - init(connection: Connection, - authentication: Authentication?, - tls: TLS + public init( + connection: Connection, + authentication: Authentication, + tls: TLS ) { self.connection = connection self.authentication = authentication @@ -129,7 +150,7 @@ public final class PostgresConnection { assert(self.isClosed, "PostgresConnection deinitialized before being closed.") } - func start(configuration: Configuration) -> EventLoopFuture { + func start(configuration: InternalConfiguration) -> EventLoopFuture { // 1. configure handlers var configureSSLCallback: ((Channel) throws -> ())? = nil @@ -186,9 +207,32 @@ public final class PostgresConnection { } } + /// Create a new connection to a Postgres server + /// + /// - Parameters: + /// - eventLoop: The `EventLoop` the request shall be created on + /// - configuration: A ``Configuration`` that shall be used for the connection + /// - connectionID: An `Int` id, used for metadata logging + /// - logger: A logger to log background events into + /// - Returns: A SwiftNIO `EventLoopFuture` that will provide a ``PostgresConnection`` + /// at a later point in time. + public static func connect( + on eventLoop: EventLoop, + configuration: PostgresConnection.Configuration, + id connectionID: ID, + logger: Logger + ) -> EventLoopFuture { + self.connect( + connectionID: connectionID, + configuration: .init(configuration), + logger: logger, + on: eventLoop + ) + } + static func connect( connectionID: ID, - configuration: PostgresConnection.Configuration, + configuration: PostgresConnection.InternalConfiguration, logger: Logger, on eventLoop: EventLoop ) -> EventLoopFuture { @@ -286,6 +330,9 @@ public final class PostgresConnection { } + /// Closes the connection to the server. + /// + /// - Returns: An EventLoopFuture that is succeeded once the connection is closed. public func close() -> EventLoopFuture { guard !self.isClosed else { return self.eventLoop.makeSucceededFuture(()) @@ -301,6 +348,10 @@ public final class PostgresConnection { extension PostgresConnection { static let idGenerator = NIOAtomic.makeAtomic(value: 0) + @available(*, deprecated, + message: "Use the new connect method that allows you to connect and authenticate in a single step", + renamed: "connect(on:configuration:id:logger:)" + ) public static func connect( to socketAddress: SocketAddress, tlsConfiguration: TLSConfiguration? = nil, @@ -319,7 +370,7 @@ extension PostgresConnection { } return tlsFuture.flatMap { tls in - let configuration = PostgresConnection.Configuration( + let configuration = PostgresConnection.InternalConfiguration( connection: .resolved(address: socketAddress, serverName: serverHostname), authentication: nil, tls: tls @@ -336,6 +387,10 @@ extension PostgresConnection { } } + @available(*, deprecated, + message: "Use the new connect method that allows you to connect and authenticate in a single step", + renamed: "connect(on:configuration:id:logger:)" + ) public func authenticate( username: String, database: String? = nil, @@ -359,7 +414,31 @@ extension PostgresConnection { #if swift(>=5.5) && canImport(_Concurrency) extension PostgresConnection { - func close() async throws { + + /// Creates a new connection to a Postgres server. + /// + /// - Parameters: + /// - eventLoop: The `EventLoop` the request shall be created on + /// - configuration: A ``Configuration`` that shall be used for the connection + /// - connectionID: An `Int` id, used for metadata logging + /// - logger: A logger to log background events into + /// - Returns: An established ``PostgresConnection`` asynchronously that can be used to run queries. + public static func connect( + on eventLoop: EventLoop, + configuration: PostgresConnection.Configuration, + id connectionID: ID, + logger: Logger + ) async throws -> PostgresConnection { + try await self.connect( + connectionID: connectionID, + configuration: .init(configuration), + logger: logger, + on: eventLoop + ).get() + } + + /// Closes the connection to the server. + public func close() async throws { try await self.close().get() } @@ -367,20 +446,18 @@ extension PostgresConnection { var logger = logger logger[postgresMetadataKey: .connectionID] = "\(self.id)" - do { - guard query.binds.count <= Int(Int16.max) else { - throw PSQLError.tooManyParameters - } - let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) - let context = ExtendedQueryContext( - query: query, - logger: logger, - promise: promise) + guard query.binds.count <= Int(Int16.max) else { + throw PSQLError.tooManyParameters + } + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let context = ExtendedQueryContext( + query: query, + logger: logger, + promise: promise) - self.channel.write(PSQLTask.extendedQuery(context), promise: nil) + self.channel.write(PSQLTask.extendedQuery(context), promise: nil) - return try await promise.futureResult.map({ $0.asyncSequence() }).get() - } + return try await promise.futureResult.map({ $0.asyncSequence() }).get() } } #endif @@ -539,7 +616,7 @@ enum CloseTarget { case portal(String) } -extension PostgresConnection.Configuration { +extension PostgresConnection.InternalConfiguration { var sslServerHostname: String? { switch self.connection { case .unresolved(let host, _): @@ -567,6 +644,33 @@ private extension String { } } +extension PostgresConnection { + /// A configuration object to bring the new ``PostgresConnection.Configuration`` together with + /// the deprecated configuration. + /// + /// TODO: Drop with next major release + struct InternalConfiguration { + enum Connection { + case unresolved(host: String, port: Int) + case resolved(address: SocketAddress, serverName: String?) + } + + var connection: Connection + + var authentication: Configuration.Authentication? + + var tls: Configuration.TLS + } +} + +extension PostgresConnection.InternalConfiguration { + init(_ config: PostgresConnection.Configuration) { + self.authentication = config.authentication + self.connection = .unresolved(host: config.connection.host, port: config.connection.port) + self.tls = config.tls + } +} + #if swift(>=5.6) extension PostgresConnection: @unchecked Sendable {} #endif diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PSQLChannelHandler.swift index 812cd358..d6dcd253 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PSQLChannelHandler.swift @@ -26,13 +26,13 @@ final class PSQLChannelHandler: ChannelDuplexHandler { private var rowStream: PSQLRowStream? private var decoder: NIOSingleStepByteToMessageProcessor private var encoder: BufferedMessageEncoder! - private let configuration: PostgresConnection.Configuration + private let configuration: PostgresConnection.InternalConfiguration private let configureSSLCallback: ((Channel) throws -> Void)? /// this delegate should only be accessed on the connections `EventLoop` weak var notificationDelegate: PSQLChannelHandlerNotificationDelegate? - init(configuration: PostgresConnection.Configuration, + init(configuration: PostgresConnection.InternalConfiguration, logger: Logger, configureSSLCallback: ((Channel) throws -> Void)?) { @@ -45,7 +45,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { #if DEBUG /// for testing purposes only - init(configuration: PostgresConnection.Configuration, + init(configuration: PostgresConnection.InternalConfiguration, state: ConnectionStateMachine = .init(.initialized), logger: Logger = .psqlNoOpLogger, configureSSLCallback: ((Channel) throws -> Void)?) @@ -575,8 +575,8 @@ private extension Insecure.MD5.Digest { } extension ConnectionStateMachine.TLSConfiguration { - fileprivate init(_ connection: PostgresConnection.Configuration.TLS) { - switch connection.base { + fileprivate init(_ tls: PostgresConnection.Configuration.TLS) { + switch tls.base { case .disable: self = .disable case .require: @@ -589,7 +589,7 @@ extension ConnectionStateMachine.TLSConfiguration { extension PSQLChannelHandler { convenience init( - configuration: PostgresConnection.Configuration, + configuration: PostgresConnection.InternalConfiguration, configureSSLCallback: ((Channel) throws -> Void)?) { self.init( diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 38e41a20..723a8034 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -23,12 +23,17 @@ final class IntegrationTests: XCTestCase { try XCTSkipIf(env("POSTGRES_HOST_AUTH_METHOD") == "trust") let config = PostgresConnection.Configuration( - host: env("POSTGRES_HOSTNAME") ?? "localhost", - port: 5432, - username: env("POSTGRES_USER") ?? "test_username", - database: env("POSTGRES_DB") ?? "test_database", - password: "wrong_password", - tls: .disable) + connection: .init( + host: env("POSTGRES_HOSTNAME") ?? "localhost", + port: 5432 + ), + authentication: .init( + username: env("POSTGRES_USER") ?? "test_username", + database: env("POSTGRES_DB") ?? "test_database", + password: "wrong_password" + ), + tls: .disable + ) let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } @@ -37,7 +42,7 @@ final class IntegrationTests: XCTestCase { logger.logLevel = .info var connection: PostgresConnection? - XCTAssertThrowsError(connection = try PostgresConnection.connect(connectionID: 1, configuration: config, logger: logger, on: eventLoopGroup.next()).wait()) { + XCTAssertThrowsError(connection = try PostgresConnection.connect(on: eventLoopGroup.next(), configuration: config, id: 1, logger: logger).wait()) { XCTAssertTrue($0 is PSQLError) } diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index ee7ecaf0..4ff68806 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -704,17 +704,23 @@ final class PostgresNIOTests: XCTestCase { func testRemoteTLSServer() { // postgres://uymgphwj:7_tHbREdRwkqAdu4KoIS7hQnNxr8J1LA@elmer.db.elephantsql.com:5432/uymgphwj var conn: PostgresConnection? - XCTAssertNoThrow(conn = try PostgresConnection.connect( - to: SocketAddress.makeAddressResolvingHost("elmer.db.elephantsql.com", port: 5432), - tlsConfiguration: .makeClientConfiguration(), - serverHostname: "elmer.db.elephantsql.com", - on: eventLoop - ).wait()) - XCTAssertNoThrow(try conn?.authenticate( - username: "uymgphwj", - database: "uymgphwj", - password: "7_tHbREdRwkqAdu4KoIS7hQnNxr8J1LA" - ).wait()) + let logger = Logger(label: "test") + let sslContext = try! NIOSSLContext(configuration: .makeClientConfiguration()) + let config = PostgresConnection.Configuration( + connection: .init( + host: "elmer.db.elephantsql.com", + port: 5432 + ), + authentication: .init( + username: "uymgphwj", + database: "uymgphwj", + password: "7_tHbREdRwkqAdu4KoIS7hQnNxr8J1LA" + ), + tls: .require(sslContext) + ) + + + XCTAssertNoThrow(conn = try PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } var rows: [PostgresRow]? XCTAssertNoThrow(rows = try conn?.simpleQuery("SELECT version()").wait()) @@ -723,6 +729,7 @@ final class PostgresNIOTests: XCTestCase { XCTAssertEqual(row?[data: "version"].string?.contains("PostgreSQL"), true) } + @available(*, deprecated, message: "Test deprecated functionality") func testFailingTLSConnectionClosesConnection() { // There was a bug (https://github.com/vapor/postgres-nio/issues/133) where we would hit // an assert because we didn't close the connection. This test should succeed without hitting @@ -744,6 +751,7 @@ final class PostgresNIOTests: XCTestCase { XCTAssertTrue(true) } + @available(*, deprecated, message: "Test deprecated functionality") func testInvalidPassword() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.testUnauthenticated(on: eventLoop).wait()) diff --git a/Tests/IntegrationTests/Utilities.swift b/Tests/IntegrationTests/Utilities.swift index 070122d1..faa19c42 100644 --- a/Tests/IntegrationTests/Utilities.swift +++ b/Tests/IntegrationTests/Utilities.swift @@ -13,6 +13,7 @@ extension PostgresConnection { try .makeAddressResolvingHost(env("POSTGRES_HOSTNAME") ?? "localhost", port: 5432) } + @available(*, deprecated, message: "Test deprecated functionality") static func testUnauthenticated(on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { var logger = Logger(label: "postgres.connection.test") logger.logLevel = logLevel @@ -24,19 +25,23 @@ extension PostgresConnection { } static func test(on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { - return testUnauthenticated(on: eventLoop, logLevel: logLevel).flatMap { conn in - return conn.authenticate( + var logger = Logger(label: "postgres.connection.test") + logger.logLevel = logLevel + + let config = PostgresConnection.Configuration( + connection: .init( + host: env("POSTGRES_HOSTNAME") ?? "localhost", + port: 5432 + ), + authentication: .init( username: env("POSTGRES_USER") ?? "test_username", database: env("POSTGRES_DB") ?? "test_database", password: env("POSTGRES_PASSWORD") ?? "test_password" - ).map { - return conn - }.flatMapError { error in - conn.close().flatMapThrowing { - throw error - } - } - } + ), + tls: .disable + ) + + return PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger) } } diff --git a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift index 01b830c4..36eac812 100644 --- a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift @@ -174,13 +174,16 @@ class PSQLChannelHandlerTests: XCTestCase { database: String = "postgres", password: String = "password", tls: PostgresConnection.Configuration.TLS = .disable - ) -> PostgresConnection.Configuration { - PostgresConnection.Configuration( - host: host, - port: port, + ) -> PostgresConnection.InternalConfiguration { + let authentication = PostgresConnection.Configuration.Authentication( username: username, database: database, - password: password, + password: password + ) + + return PostgresConnection.InternalConfiguration( + connection: .unresolved(host: host, port: port), + authentication: authentication, tls: tls ) } diff --git a/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift b/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift index 260705c2..2d50cb0f 100644 --- a/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift @@ -22,18 +22,15 @@ class PSQLConnectionTests: XCTestCase { } let config = PostgresConnection.Configuration( - host: "127.0.0.1", - port: port, - username: "postgres", - database: "postgres", - password: "abc123", + connection: .init(host: "127.0.0.1", port: port), + authentication: .init(username: "postgres", database: "postgres", password: "abc123"), tls: .disable ) var logger = Logger.psqlTest logger.logLevel = .trace - XCTAssertThrowsError(try PostgresConnection.connect(connectionID: 1, configuration: config, logger: logger, on: eventLoopGroup.next()).wait()) { + XCTAssertThrowsError(try PostgresConnection.connect(on: eventLoopGroup.next(), configuration: config, id: 1, logger: logger).wait()) { XCTAssertTrue($0 is PSQLError) } } From ba0d2bbf762c955d4c8564aa13244901a66eb3de Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 8 Mar 2022 21:24:33 +0100 Subject: [PATCH 081/292] Make new Postgres decoding public (#244) --- Sources/PostgresNIO/Data/PostgresRow.swift | 4 +- .../New/Data/Array+PostgresCodable.swift | 45 ++++++++------ .../New/Data/Bool+PostgresCodable.swift | 3 +- .../New/Data/Bytes+PostgresCodable.swift | 6 +- .../New/Data/Date+PostgresCodable.swift | 14 +++-- .../New/Data/Decimal+PostgresCodable.swift | 2 +- .../New/Data/Float+PostgresCodable.swift | 6 +- .../New/Data/Int+PostgresCodable.swift | 15 +++-- .../New/Data/String+PostgresCodable.swift | 4 +- .../New/Data/UUID+PostgresCodable.swift | 4 +- .../New/Extensions/ByteBuffer+PSQL.swift | 2 + .../PostgresNIO/New/Messages/DataRow.swift | 41 +++++++------ .../New/Messages/RowDescription.swift | 12 +++- Sources/PostgresNIO/New/PSQLError.swift | 31 +++++++++- Sources/PostgresNIO/New/PostgresCell.swift | 3 +- Sources/PostgresNIO/New/PostgresCodable.swift | 11 ++-- .../New/PostgresRow-multi-decode.swift | 60 ++++++++++++++----- dev/generate-postgresrow-multi-decode.sh | 6 +- 18 files changed, 190 insertions(+), 79 deletions(-) diff --git a/Sources/PostgresNIO/Data/PostgresRow.swift b/Sources/PostgresNIO/Data/PostgresRow.swift index 83343812..3fda262a 100644 --- a/Sources/PostgresNIO/Data/PostgresRow.swift +++ b/Sources/PostgresNIO/Data/PostgresRow.swift @@ -12,9 +12,11 @@ import class Foundation.JSONDecoder /// random access to cells in O(1) create a new ``PostgresRandomAccessRow`` with the given row and /// access it instead. public struct PostgresRow { + @usableFromInline let lookupTable: [String: Int] + @usableFromInline let data: DataRow - + @usableFromInline let columns: [RowDescription.Column] init(data: DataRow, lookupTable: [String: Int], columns: [RowDescription.Column]) { diff --git a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift index 91edc9a1..aae9ad32 100644 --- a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift @@ -103,8 +103,23 @@ extension Array: PostgresEncodable where Element: PSQLArrayElement { } } -extension Array: PostgresDecodable where Element: PSQLArrayElement { - init( +/// A type that can be decoded into a Swift Array of its own type from a Postgres array. +public protocol PostgresArrayDecodable: PostgresDecodable {} + +extension Bool: PostgresArrayDecodable {} +extension ByteBuffer: PostgresArrayDecodable {} +extension UInt8: PostgresArrayDecodable {} +extension Int16: PostgresArrayDecodable {} +extension Int32: PostgresArrayDecodable {} +extension Int64: PostgresArrayDecodable {} +extension Int: PostgresArrayDecodable {} +extension Float: PostgresArrayDecodable {} +extension Double: PostgresArrayDecodable {} +extension String: PostgresArrayDecodable {} +extension UUID: PostgresArrayDecodable {} + +extension Array: PostgresDecodable where Element: PostgresArrayDecodable, Element == Element._DecodableType { + public init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, @@ -114,48 +129,44 @@ extension Array: PostgresDecodable where Element: PSQLArrayElement { // currently we only support decoding arrays in binary format. throw PostgresCastingError.Code.failure } - + guard let (isNotEmpty, b, element) = buffer.readMultipleIntegers(endianness: .big, as: (Int32, Int32, UInt32).self), 0 <= isNotEmpty, isNotEmpty <= 1, b == 0 else { throw PostgresCastingError.Code.failure } - + let elementType = PostgresDataType(element) - + guard isNotEmpty == 1 else { self = [] return } - + guard let (expectedArrayCount, dimensions) = buffer.readMultipleIntegers(endianness: .big, as: (Int32, Int32).self), expectedArrayCount > 0, dimensions == 1 else { throw PostgresCastingError.Code.failure } - + var result = Array() result.reserveCapacity(Int(expectedArrayCount)) - + for _ in 0 ..< expectedArrayCount { - guard let elementLength = buffer.readInteger(as: Int32.self) else { + guard let elementLength = buffer.readInteger(as: Int32.self), elementLength >= 0 else { throw PostgresCastingError.Code.failure } - + guard var elementBuffer = buffer.readSlice(length: numericCast(elementLength)) else { throw PostgresCastingError.Code.failure } - + let element = try Element.init(from: &elementBuffer, type: elementType, format: format, context: context) - + result.append(element) } - + self = result } } - -extension Array: PostgresCodable where Element: PSQLArrayElement { - -} diff --git a/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift index 88609d13..baf828aa 100644 --- a/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift @@ -1,7 +1,8 @@ import NIOCore extension Bool: PostgresDecodable { - init( + @inlinable + public init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, diff --git a/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift index 168d9c69..53d6df17 100644 --- a/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift @@ -38,7 +38,8 @@ extension ByteBuffer: PostgresEncodable { } extension ByteBuffer: PostgresDecodable { - init( + @inlinable + public init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, @@ -68,7 +69,8 @@ extension Data: PostgresEncodable { } extension Data: PostgresDecodable { - init( + @inlinable + public init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, diff --git a/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift index 8c164f1c..960f3c02 100644 --- a/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift @@ -19,16 +19,20 @@ extension Date: PostgresEncodable { } // MARK: Private Constants - - private static let _microsecondsPerSecond: Int64 = 1_000_000 - private static let _secondsInDay: Int64 = 24 * 60 * 60 + + @usableFromInline + static let _microsecondsPerSecond: Int64 = 1_000_000 + @usableFromInline + static let _secondsInDay: Int64 = 24 * 60 * 60 /// values are stored as seconds before or after midnight 2000-01-01 - private static let _psqlDateStart = Date(timeIntervalSince1970: 946_684_800) + @usableFromInline + static let _psqlDateStart = Date(timeIntervalSince1970: 946_684_800) } extension Date: PostgresDecodable { - init( + @inlinable + public init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, diff --git a/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift index e80da7be..43432302 100644 --- a/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift @@ -25,7 +25,7 @@ extension Decimal: PostgresEncodable { } extension Decimal: PostgresDecodable { - init( + public init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, diff --git a/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift index 1a39be18..8951e8b2 100644 --- a/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift @@ -18,7 +18,8 @@ extension Float: PostgresEncodable { } extension Float: PostgresDecodable { - init( + @inlinable + public init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, @@ -66,7 +67,8 @@ extension Double: PostgresEncodable { } extension Double: PostgresDecodable { - init( + @inlinable + public init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, diff --git a/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift index e399a406..3b15cae0 100644 --- a/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift @@ -20,7 +20,8 @@ extension UInt8: PostgresEncodable { } extension UInt8: PostgresDecodable { - init( + @inlinable + public init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, @@ -62,7 +63,8 @@ extension Int16: PostgresEncodable { } extension Int16: PostgresDecodable { - init( + @inlinable + public init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, @@ -107,7 +109,8 @@ extension Int32: PostgresEncodable { } extension Int32: PostgresDecodable { - init( + @inlinable + public init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, @@ -157,7 +160,8 @@ extension Int64: PostgresEncodable { } extension Int64: PostgresDecodable { - init( + @inlinable + public init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, @@ -219,7 +223,8 @@ extension Int: PostgresEncodable { } extension Int: PostgresDecodable { - init( + @inlinable + public init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, diff --git a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift index 56080540..c00a3829 100644 --- a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift @@ -19,7 +19,9 @@ extension String: PostgresEncodable { } extension String: PostgresDecodable { - init( + + @inlinable + public init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, diff --git a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift index 2ec813bd..facb7e95 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift @@ -26,7 +26,8 @@ extension UUID: PostgresEncodable { } extension UUID: PostgresDecodable { - init( + @inlinable + public init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, @@ -60,6 +61,7 @@ extension UUID: PostgresDecodable { extension UUID: PostgresCodable {} extension ByteBuffer { + @usableFromInline mutating func readUUID() -> UUID? { guard self.readableBytes >= MemoryLayout.size else { return nil diff --git a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift index f226bd7b..9543ffd1 100644 --- a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift +++ b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift @@ -10,10 +10,12 @@ internal extension ByteBuffer { self.writeInteger(messageID.rawValue) } + @usableFromInline mutating func psqlReadFloat() -> Float? { return self.readInteger(as: UInt32.self).map { Float(bitPattern: $0) } } + @usableFromInline mutating func psqlReadDouble() -> Double? { return self.readInteger(as: UInt64.self).map { Double(bitPattern: $0) } } diff --git a/Sources/PostgresNIO/New/Messages/DataRow.swift b/Sources/PostgresNIO/New/Messages/DataRow.swift index b49c9eeb..b5c3f8e7 100644 --- a/Sources/PostgresNIO/New/Messages/DataRow.swift +++ b/Sources/PostgresNIO/New/Messages/DataRow.swift @@ -8,10 +8,11 @@ import NIOCore /// enclosing type, the enclosing type must be @usableFromInline as well. /// Not putting `DataRow` in ``PSQLBackendMessage`` is our way to trick /// the Swift compiler +@usableFromInline struct DataRow: PostgresBackendMessage.PayloadDecodable, Equatable { - + @usableFromInline var columnCount: Int16 - + @usableFromInline var bytes: ByteBuffer static func decode(from buffer: inout ByteBuffer) throws -> Self { @@ -35,43 +36,48 @@ struct DataRow: PostgresBackendMessage.PayloadDecodable, Equatable { } extension DataRow: Sequence { + @usableFromInline typealias Element = ByteBuffer? - - // There is no contiguous storage available... Sadly - func withContiguousStorageIfAvailable(_ body: (UnsafeBufferPointer) throws -> R) rethrows -> R? { - nil - } } extension DataRow: Collection { - + + @usableFromInline struct ColumnIndex: Comparable { + @usableFromInline var offset: Int - + + @inlinable init(_ index: Int) { self.offset = index } // Only needed implementation for comparable. The compiler synthesizes the rest from this. + @inlinable static func < (lhs: Self, rhs: Self) -> Bool { lhs.offset < rhs.offset } } - + + @usableFromInline typealias Index = DataRow.ColumnIndex - + + @inlinable var startIndex: ColumnIndex { ColumnIndex(self.bytes.readerIndex) } - + + @inlinable var endIndex: ColumnIndex { ColumnIndex(self.bytes.readerIndex + self.bytes.readableBytes) } - + + @inlinable var count: Int { Int(self.columnCount) } - + + @inlinable func index(after index: ColumnIndex) -> ColumnIndex { guard index < self.endIndex else { preconditionFailure("index out of bounds") @@ -82,7 +88,8 @@ extension DataRow: Collection { } return ColumnIndex(index.offset + MemoryLayout.size + elementLength) } - + + @inlinable subscript(index: ColumnIndex) -> Element { guard index < self.endIndex else { preconditionFailure("index out of bounds") @@ -100,12 +107,12 @@ extension DataRow { guard index < self.columnCount else { preconditionFailure("index out of bounds") } - + var byteIndex = self.startIndex for _ in 0.. Bool { return lhs.code == rhs.code && lhs.columnName == rhs.columnName diff --git a/Sources/PostgresNIO/New/PostgresCell.swift b/Sources/PostgresNIO/New/PostgresCell.swift index 624e845d..5281a798 100644 --- a/Sources/PostgresNIO/New/PostgresCell.swift +++ b/Sources/PostgresNIO/New/PostgresCell.swift @@ -24,7 +24,8 @@ public struct PostgresCell: Equatable { extension PostgresCell { - func decode( + @inlinable + public func decode( _: T.Type, context: PostgresDecodingContext, file: String = #file, diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift index b197fdd6..55c55df4 100644 --- a/Sources/PostgresNIO/New/PostgresCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -17,7 +17,7 @@ protocol PostgresEncodable { /// A type that can decode itself from a postgres wire binary representation. /// /// If you want to conform a type to PostgresDecodable you must implement the decode method. -protocol PostgresDecodable { +public protocol PostgresDecodable { /// A type definition of the type that actually implements the PostgresDecodable protocol. This is an escape hatch to /// prevent a cycle in the conformace of the Optional type to PostgresDecodable. /// @@ -54,7 +54,7 @@ protocol PostgresDecodable { extension PostgresDecodable { @inlinable - static func _decodeRaw( + public static func _decodeRaw( from byteBuffer: inout ByteBuffer?, type: PostgresDataType, format: PostgresFormat, @@ -133,9 +133,9 @@ extension PostgresDecodingContext where JSONDecoder == Foundation.JSONDecoder { } extension Optional: PostgresDecodable where Wrapped: PostgresDecodable, Wrapped._DecodableType == Wrapped { - typealias _DecodableType = Wrapped + public typealias _DecodableType = Wrapped - init( + public init( from byteBuffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, @@ -144,7 +144,8 @@ extension Optional: PostgresDecodable where Wrapped: PostgresDecodable, Wrapped. preconditionFailure("This should not be called") } - static func _decodeRaw( + @inlinable + public static func _decodeRaw( from byteBuffer: inout ByteBuffer?, type: PostgresDataType, format: PostgresFormat, diff --git a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift index 1e1a426d..6ca0e54b 100644 --- a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift +++ b/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift @@ -1,7 +1,9 @@ /// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrow-multi-decode.sh extension PostgresRow { - func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0) { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0) { precondition(self.columns.count >= 1) let columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -29,7 +31,9 @@ extension PostgresRow { } } - func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1) { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1) { precondition(self.columns.count >= 2) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -63,7 +67,9 @@ extension PostgresRow { } } - func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2) { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2) { precondition(self.columns.count >= 3) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -103,7 +109,9 @@ extension PostgresRow { } } - func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3) { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3) { precondition(self.columns.count >= 4) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -149,7 +157,9 @@ extension PostgresRow { } } - func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { precondition(self.columns.count >= 5) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -201,7 +211,9 @@ extension PostgresRow { } } - func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { precondition(self.columns.count >= 6) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -259,7 +271,9 @@ extension PostgresRow { } } - func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { precondition(self.columns.count >= 7) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -323,7 +337,9 @@ extension PostgresRow { } } - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { precondition(self.columns.count >= 8) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -393,7 +409,9 @@ extension PostgresRow { } } - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { precondition(self.columns.count >= 9) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -469,7 +487,9 @@ extension PostgresRow { } } - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { precondition(self.columns.count >= 10) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -551,7 +571,9 @@ extension PostgresRow { } } - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { precondition(self.columns.count >= 11) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -639,7 +661,9 @@ extension PostgresRow { } } - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { precondition(self.columns.count >= 12) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -733,7 +757,9 @@ extension PostgresRow { } } - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { precondition(self.columns.count >= 13) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -833,7 +859,9 @@ extension PostgresRow { } } - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { precondition(self.columns.count >= 14) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -939,7 +967,9 @@ extension PostgresRow { } } - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { precondition(self.columns.count >= 15) var columnIndex = 0 var cellIterator = self.data.makeIterator() diff --git a/dev/generate-postgresrow-multi-decode.sh b/dev/generate-postgresrow-multi-decode.sh index b99be562..2fb98c24 100755 --- a/dev/generate-postgresrow-multi-decode.sh +++ b/dev/generate-postgresrow-multi-decode.sh @@ -11,9 +11,9 @@ function gen() { echo "" fi - #echo " @inlinable" - #echo " @_alwaysEmitIntoClient" - echo -n " func decode Date: Tue, 8 Mar 2022 23:17:55 +0100 Subject: [PATCH 082/292] Make PostgresRowSequence public (#247) --- .../PostgresRowSequence-multi-decode.swift | 60 ++++++++++++++----- .../PostgresNIO/New/PostgresRowSequence.swift | 19 +++--- Tests/IntegrationTests/AsyncTests.swift | 4 +- .../New/PostgresRowSequenceTests.swift | 2 +- ...nerate-postgresrowsequence-multi-decode.sh | 6 +- 5 files changed, 60 insertions(+), 31 deletions(-) diff --git a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift index aea721e4..0b3302c1 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift @@ -2,91 +2,121 @@ #if swift(>=5.5) && canImport(_Concurrency) extension PostgresRowSequence { - func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode(T0.self, context: context, file: file, line: line) } } - func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1).self, context: context, file: file, line: line) } } - func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2).self, context: context, file: file, line: line) } } - func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3).self, context: context, file: file, line: line) } } - func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4).self, context: context, file: file, line: line) } } - func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5).self, context: context, file: file, line: line) } } - func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6).self, context: context, file: file, line: line) } } - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: context, file: file, line: line) } } - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: context, file: file, line: line) } } - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: context, file: file, line: line) } } - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: context, file: file, line: line) } } - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: context, file: file, line: line) } } - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: context, file: file, line: line) } } - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: context, file: file, line: line) } } - func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: context, file: file, line: line) } diff --git a/Sources/PostgresNIO/New/PostgresRowSequence.swift b/Sources/PostgresNIO/New/PostgresRowSequence.swift index 4a87b452..8159e679 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence.swift @@ -5,9 +5,8 @@ import NIOConcurrencyHelpers /// An async sequence of ``PostgresRow``s. /// /// - Note: This is a struct to allow us to move to a move only type easily once they become available. -struct PostgresRowSequence: AsyncSequence { - typealias Element = PostgresRow - typealias AsyncIterator = Iterator +public struct PostgresRowSequence: AsyncSequence { + public typealias Element = PostgresRow final class _Internal { @@ -22,7 +21,7 @@ struct PostgresRowSequence: AsyncSequence { self.consumer.sequenceDeinitialized() } - func makeAsyncIterator() -> Iterator { + func makeAsyncIterator() -> AsyncIterator { self.consumer.makeAsyncIterator() } } @@ -33,14 +32,14 @@ struct PostgresRowSequence: AsyncSequence { self._internal = .init(consumer: consumer) } - func makeAsyncIterator() -> Iterator { + public func makeAsyncIterator() -> AsyncIterator { self._internal.makeAsyncIterator() } } extension PostgresRowSequence { - struct Iterator: AsyncIteratorProtocol { - typealias Element = PostgresRow + public struct AsyncIterator: AsyncIteratorProtocol { + public typealias Element = PostgresRow let _internal: _Internal @@ -48,7 +47,7 @@ extension PostgresRowSequence { self._internal = _Internal(consumer: consumer) } - mutating func next() async throws -> PostgresRow? { + public mutating func next() async throws -> PostgresRow? { try await self._internal.next() } @@ -155,11 +154,11 @@ final class AsyncStreamConsumer { } } - func makeAsyncIterator() -> PostgresRowSequence.Iterator { + func makeAsyncIterator() -> PostgresRowSequence.AsyncIterator { self.lock.withLock { self.state.createAsyncIterator() } - let iterator = PostgresRowSequence.Iterator(consumer: self) + let iterator = PostgresRowSequence.AsyncIterator(consumer: self) return iterator } diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 593a06e0..691c334f 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -33,8 +33,8 @@ final class AsyncPostgresConnectionTests: XCTestCase { try await withTestConnection(on: eventLoop) { connection in let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest) var counter = 1 - for try await row in rows { - XCTAssertEqual(try row.decode(Int.self, context: .default), counter) + for try await element in rows.decode(Int.self, context: .default) { + XCTAssertEqual(element, counter) counter += 1 } diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index 9e01ff06..6d7bc24b 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -90,7 +90,7 @@ final class PostgresRowSequenceTests: XCTestCase { let dataRows: [DataRow] = (0..<128).map { [ByteBuffer(integer: Int64($0))] } stream.receive(dataRows) - var iterator: PostgresRowSequence.Iterator? = rowSequence.makeAsyncIterator() + var iterator: PostgresRowSequence.AsyncIterator? = rowSequence.makeAsyncIterator() iterator = nil XCTAssertEqual(dataSource.cancelCount, 1) diff --git a/dev/generate-postgresrowsequence-multi-decode.sh b/dev/generate-postgresrowsequence-multi-decode.sh index eb5ad9a0..284b0049 100755 --- a/dev/generate-postgresrowsequence-multi-decode.sh +++ b/dev/generate-postgresrowsequence-multi-decode.sh @@ -11,9 +11,9 @@ function gen() { echo "" fi - #echo " @inlinable" - #echo " @_alwaysEmitIntoClient" - echo -n " func decode Date: Fri, 11 Mar 2022 09:05:38 +0100 Subject: [PATCH 083/292] Make Postgres Encodable public (#248) --- .../New/Data/Array+PostgresCodable.swift | 131 +++++++++--------- .../New/Data/Bool+PostgresCodable.swift | 9 +- .../New/Data/Bytes+PostgresCodable.swift | 25 ++-- .../New/Data/Date+PostgresCodable.swift | 9 +- .../New/Data/Decimal+PostgresCodable.swift | 8 +- .../New/Data/Float+PostgresCodable.swift | 18 +-- .../New/Data/Int+PostgresCodable.swift | 50 ++++--- .../New/Data/JSON+PostgresCodable.swift | 16 ++- .../RawRepresentable+PostgresCodable.swift | 13 +- .../New/Data/String+PostgresCodable.swift | 9 +- .../New/Data/UUID+PostgresCodable.swift | 9 +- .../New/Extensions/ByteBuffer+PSQL.swift | 2 + Sources/PostgresNIO/New/PostgresCodable.swift | 7 +- Sources/PostgresNIO/New/PostgresQuery.swift | 2 +- .../PSQLIntegrationTests.swift | 6 +- .../New/Data/Array+PSQLCodableTests.swift | 60 ++++---- .../New/Data/Bool+PSQLCodableTests.swift | 8 +- .../New/Data/Bytes+PSQLCodableTests.swift | 6 +- .../New/Data/Date+PSQLCodableTests.swift | 2 +- .../New/Data/Decimal+PSQLCodableTests.swift | 2 +- .../New/Data/Float+PSQLCodableTests.swift | 12 +- .../New/Data/JSON+PSQLCodableTests.swift | 2 +- .../RawRepresentable+PSQLCodableTests.swift | 8 +- .../New/Data/String+PSQLCodableTests.swift | 2 +- .../New/Data/UUID+PSQLCodableTests.swift | 4 +- 25 files changed, 224 insertions(+), 196 deletions(-) diff --git a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift index aae9ad32..dd4e5620 100644 --- a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift @@ -1,82 +1,103 @@ import NIOCore import struct Foundation.UUID +// MARK: Protocols + /// A type, of which arrays can be encoded into and decoded from a postgres binary format -protocol PSQLArrayElement: PostgresCodable { +public protocol PostgresArrayEncodable: PostgresEncodable { static var psqlArrayType: PostgresDataType { get } - static var psqlArrayElementType: PostgresDataType { get } } -extension Bool: PSQLArrayElement { - static var psqlArrayType: PostgresDataType { .boolArray } - static var psqlArrayElementType: PostgresDataType { .bool } +/// A type that can be decoded into a Swift Array of its own type from a Postgres array. +public protocol PostgresArrayDecodable: PostgresDecodable {} + +// MARK: Element conformances + +extension Bool: PostgresArrayDecodable {} + +extension Bool: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .boolArray } } -extension ByteBuffer: PSQLArrayElement { - static var psqlArrayType: PostgresDataType { .byteaArray } - static var psqlArrayElementType: PostgresDataType { .bytea } +extension ByteBuffer: PostgresArrayDecodable {} + +extension ByteBuffer: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .byteaArray } } -extension UInt8: PSQLArrayElement { - static var psqlArrayType: PostgresDataType { .charArray } - static var psqlArrayElementType: PostgresDataType { .char } +extension UInt8: PostgresArrayDecodable {} + +extension UInt8: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .charArray } } -extension Int16: PSQLArrayElement { - static var psqlArrayType: PostgresDataType { .int2Array } - static var psqlArrayElementType: PostgresDataType { .int2 } + +extension Int16: PostgresArrayDecodable {} + +extension Int16: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .int2Array } } -extension Int32: PSQLArrayElement { - static var psqlArrayType: PostgresDataType { .int4Array } - static var psqlArrayElementType: PostgresDataType { .int4 } +extension Int32: PostgresArrayDecodable {} + +extension Int32: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .int4Array } } -extension Int64: PSQLArrayElement { - static var psqlArrayType: PostgresDataType { .int8Array } - static var psqlArrayElementType: PostgresDataType { .int8 } +extension Int64: PostgresArrayDecodable {} + +extension Int64: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .int8Array } } -extension Int: PSQLArrayElement { - #if (arch(i386) || arch(arm)) - static var psqlArrayType: PostgresDataType { .int4Array } - static var psqlArrayElementType: PostgresDataType { .int4 } - #else - static var psqlArrayType: PostgresDataType { .int8Array } - static var psqlArrayElementType: PostgresDataType { .int8 } - #endif +extension Int: PostgresArrayDecodable {} + +extension Int: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { + if MemoryLayout.size == 8 { + return .int8Array + } + return .int4Array + } } -extension Float: PSQLArrayElement { - static var psqlArrayType: PostgresDataType { .float4Array } - static var psqlArrayElementType: PostgresDataType { .float4 } +extension Float: PostgresArrayDecodable {} + +extension Float: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .float4Array } } -extension Double: PSQLArrayElement { - static var psqlArrayType: PostgresDataType { .float8Array } - static var psqlArrayElementType: PostgresDataType { .float8 } +extension Double: PostgresArrayDecodable {} + +extension Double: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .float8Array } } -extension String: PSQLArrayElement { - static var psqlArrayType: PostgresDataType { .textArray } - static var psqlArrayElementType: PostgresDataType { .text } +extension String: PostgresArrayDecodable {} + +extension String: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .textArray } } -extension UUID: PSQLArrayElement { - static var psqlArrayType: PostgresDataType { .uuidArray } - static var psqlArrayElementType: PostgresDataType { .uuid } +extension UUID: PostgresArrayDecodable {} + +extension UUID: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .uuidArray } } -extension Array: PostgresEncodable where Element: PSQLArrayElement { - var psqlType: PostgresDataType { +// MARK: Array conformances + +extension Array: PostgresEncodable where Element: PostgresArrayEncodable { + public static var psqlType: PostgresDataType { Element.psqlArrayType } - - var psqlFormat: PostgresFormat { + + public static var psqlFormat: PostgresFormat { .binary } - - func encode( + + @inlinable + public func encode( into buffer: inout ByteBuffer, context: PostgresEncodingContext ) throws { @@ -85,13 +106,13 @@ extension Array: PostgresEncodable where Element: PSQLArrayElement { // b buffer.writeInteger(0, as: Int32.self) // array element type - buffer.writeInteger(Element.psqlArrayElementType.rawValue) + buffer.writeInteger(Element.psqlType.rawValue) // continue if the array is not empty guard !self.isEmpty else { return } - + // length of array buffer.writeInteger(numericCast(self.count), as: Int32.self) // dimensions @@ -103,20 +124,6 @@ extension Array: PostgresEncodable where Element: PSQLArrayElement { } } -/// A type that can be decoded into a Swift Array of its own type from a Postgres array. -public protocol PostgresArrayDecodable: PostgresDecodable {} - -extension Bool: PostgresArrayDecodable {} -extension ByteBuffer: PostgresArrayDecodable {} -extension UInt8: PostgresArrayDecodable {} -extension Int16: PostgresArrayDecodable {} -extension Int32: PostgresArrayDecodable {} -extension Int64: PostgresArrayDecodable {} -extension Int: PostgresArrayDecodable {} -extension Float: PostgresArrayDecodable {} -extension Double: PostgresArrayDecodable {} -extension String: PostgresArrayDecodable {} -extension UUID: PostgresArrayDecodable {} extension Array: PostgresDecodable where Element: PostgresArrayDecodable, Element == Element._DecodableType { public init( diff --git a/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift index baf828aa..13308265 100644 --- a/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift @@ -44,15 +44,16 @@ extension Bool: PostgresDecodable { } extension Bool: PostgresEncodable { - var psqlType: PostgresDataType { + public static var psqlType: PostgresDataType { .bool } - var psqlFormat: PostgresFormat { + public static var psqlFormat: PostgresFormat { .binary } - - func encode( + + @inlinable + public func encode( into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext ) { diff --git a/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift index 53d6df17..edf79462 100644 --- a/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift @@ -3,15 +3,16 @@ import NIOCore import NIOFoundationCompat extension PostgresEncodable where Self: Sequence, Self.Element == UInt8 { - var psqlType: PostgresDataType { + public static var psqlType: PostgresDataType { .bytea } - var psqlFormat: PostgresFormat { + public static var psqlFormat: PostgresFormat { .binary } - - func encode( + + @inlinable + public func encode( into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext ) { @@ -20,15 +21,16 @@ extension PostgresEncodable where Self: Sequence, Self.Element == UInt8 { } extension ByteBuffer: PostgresEncodable { - var psqlType: PostgresDataType { + public static var psqlType: PostgresDataType { .bytea } - var psqlFormat: PostgresFormat { + public static var psqlFormat: PostgresFormat { .binary } - - func encode( + + @inlinable + public func encode( into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext ) { @@ -52,15 +54,16 @@ extension ByteBuffer: PostgresDecodable { extension ByteBuffer: PostgresCodable {} extension Data: PostgresEncodable { - var psqlType: PostgresDataType { + public static var psqlType: PostgresDataType { .bytea } - var psqlFormat: PostgresFormat { + public static var psqlFormat: PostgresFormat { .binary } - func encode( + @inlinable + public func encode( into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext ) { diff --git a/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift index 960f3c02..4a1848ec 100644 --- a/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift @@ -2,15 +2,16 @@ import NIOCore import struct Foundation.Date extension Date: PostgresEncodable { - var psqlType: PostgresDataType { + public static var psqlType: PostgresDataType { .timestamptz } - var psqlFormat: PostgresFormat { + public static var psqlFormat: PostgresFormat { .binary } - - func encode( + + @inlinable + public func encode( into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext ) { diff --git a/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift index 43432302..3f1c7fa0 100644 --- a/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift @@ -2,15 +2,15 @@ import NIOCore import struct Foundation.Decimal extension Decimal: PostgresEncodable { - var psqlType: PostgresDataType { + public static var psqlType: PostgresDataType { .numeric } - var psqlFormat: PostgresFormat { + public static var psqlFormat: PostgresFormat { .binary } - - func encode( + + public func encode( into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext ) { diff --git a/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift index 8951e8b2..d653e9d8 100644 --- a/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift @@ -1,15 +1,16 @@ import NIOCore extension Float: PostgresEncodable { - var psqlType: PostgresDataType { + public static var psqlType: PostgresDataType { .float4 } - var psqlFormat: PostgresFormat { + public static var psqlFormat: PostgresFormat { .binary } - - func encode( + + @inlinable + public func encode( into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext ) { @@ -50,15 +51,16 @@ extension Float: PostgresDecodable { extension Float: PostgresCodable {} extension Double: PostgresEncodable { - var psqlType: PostgresDataType { + public static var psqlType: PostgresDataType { .float8 } - var psqlFormat: PostgresFormat { + public static var psqlFormat: PostgresFormat { .binary } - - func encode( + + @inlinable + public func encode( into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext ) { diff --git a/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift index 3b15cae0..7ea81f31 100644 --- a/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift @@ -3,15 +3,16 @@ import NIOCore // MARK: UInt8 extension UInt8: PostgresEncodable { - var psqlType: PostgresDataType { + public static var psqlType: PostgresDataType { .char } - var psqlFormat: PostgresFormat { + public static var psqlFormat: PostgresFormat { .binary } - func encode( + @inlinable + public func encode( into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext ) { @@ -45,16 +46,16 @@ extension UInt8: PostgresCodable {} // MARK: Int16 extension Int16: PostgresEncodable { - - var psqlType: PostgresDataType { + public static var psqlType: PostgresDataType { .int2 } - var psqlFormat: PostgresFormat { + public static var psqlFormat: PostgresFormat { .binary } - func encode( + @inlinable + public func encode( into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext ) { @@ -92,15 +93,16 @@ extension Int16: PostgresCodable {} // MARK: Int32 extension Int32: PostgresEncodable { - var psqlType: PostgresDataType { + public static var psqlType: PostgresDataType { .int4 } - var psqlFormat: PostgresFormat { + public static var psqlFormat: PostgresFormat { .binary } - func encode( + @inlinable + public func encode( into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext ) { @@ -143,15 +145,16 @@ extension Int32: PostgresCodable {} // MARK: Int64 extension Int64: PostgresEncodable { - var psqlType: PostgresDataType { + public static var psqlType: PostgresDataType { .int8 } - var psqlFormat: PostgresFormat { + public static var psqlFormat: PostgresFormat { .binary } - func encode( + @inlinable + public func encode( into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext ) { @@ -199,22 +202,23 @@ extension Int64: PostgresCodable {} // MARK: Int extension Int: PostgresEncodable { - var psqlType: PostgresDataType { - switch self.bitWidth { - case Int32.bitWidth: + public static var psqlType: PostgresDataType { + switch MemoryLayout.size { + case 4: return .int4 - case Int64.bitWidth: + case 8: return .int8 default: preconditionFailure("Int is expected to be an Int32 or Int64") } } - var psqlFormat: PostgresFormat { + public static var psqlFormat: PostgresFormat { .binary } - func encode( + @inlinable + public func encode( into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext ) { @@ -237,12 +241,12 @@ extension Int: PostgresDecodable { } self = Int(value) case (.binary, .int4): - guard buffer.readableBytes == 4, let value = buffer.readInteger(as: Int32.self) else { + guard buffer.readableBytes == 4, let value = buffer.readInteger(as: Int32.self).flatMap({ Int(exactly: $0) }) else { throw PostgresCastingError.Code.failure } - self = Int(value) - case (.binary, .int8) where Int.bitWidth == 64: - guard buffer.readableBytes == 8, let value = buffer.readInteger(as: Int.self) else { + self = value + case (.binary, .int8): + guard buffer.readableBytes == 8, let value = buffer.readInteger(as: Int.self).flatMap({ Int(exactly: $0) }) else { throw PostgresCastingError.Code.failure } self = value diff --git a/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift index a506c2d6..2e09d03e 100644 --- a/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift @@ -3,18 +3,20 @@ import NIOFoundationCompat import class Foundation.JSONEncoder import class Foundation.JSONDecoder -private let JSONBVersionByte: UInt8 = 0x01 +@usableFromInline +let JSONBVersionByte: UInt8 = 0x01 -extension PostgresEncodable where Self: Codable { - var psqlType: PostgresDataType { +extension PostgresEncodable where Self: Encodable { + public static var psqlType: PostgresDataType { .jsonb } - var psqlFormat: PostgresFormat { + public static var psqlFormat: PostgresFormat { .binary } - - func encode( + + @inlinable + public func encode( into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext ) throws { @@ -23,7 +25,7 @@ extension PostgresEncodable where Self: Codable { } } -extension PostgresDecodable where Self: Codable { +extension PostgresDecodable where Self: Decodable { init( from buffer: inout ByteBuffer, type: PostgresDataType, diff --git a/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift index c64da931..9a4f6b1d 100644 --- a/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift @@ -1,15 +1,16 @@ import NIOCore extension PostgresEncodable where Self: RawRepresentable, RawValue: PostgresEncodable { - var psqlType: PostgresDataType { - self.rawValue.psqlType + public static var psqlType: PostgresDataType { + RawValue.psqlType } - var psqlFormat: PostgresFormat { - self.rawValue.psqlFormat + public static var psqlFormat: PostgresFormat { + RawValue.psqlFormat } - - func encode( + + @inlinable + public func encode( into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext ) throws { diff --git a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift index c00a3829..aebfedcd 100644 --- a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift @@ -2,15 +2,16 @@ import NIOCore import struct Foundation.UUID extension String: PostgresEncodable { - var psqlType: PostgresDataType { + public static var psqlType: PostgresDataType { .text } - var psqlFormat: PostgresFormat { + public static var psqlFormat: PostgresFormat { .binary } - - func encode( + + @inlinable + public func encode( into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext ) { diff --git a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift index facb7e95..f40fff7c 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift @@ -3,15 +3,16 @@ import struct Foundation.UUID import typealias Foundation.uuid_t extension UUID: PostgresEncodable { - var psqlType: PostgresDataType { + public static var psqlType: PostgresDataType { .uuid } - var psqlFormat: PostgresFormat { + public static var psqlFormat: PostgresFormat { .binary } - - func encode( + + @inlinable + public func encode( into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext ) { diff --git a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift index 9543ffd1..6d632b6f 100644 --- a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift +++ b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift @@ -20,10 +20,12 @@ internal extension ByteBuffer { return self.readInteger(as: UInt64.self).map { Double(bitPattern: $0) } } + @usableFromInline mutating func psqlWriteFloat(_ float: Float) { self.writeInteger(float.bitPattern) } + @usableFromInline mutating func psqlWriteDouble(_ double: Double) { self.writeInteger(double.bitPattern) } diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift index 55c55df4..c90594cf 100644 --- a/Sources/PostgresNIO/New/PostgresCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -2,12 +2,12 @@ import NIOCore import Foundation /// A type that can encode itself to a postgres wire binary representation. -protocol PostgresEncodable { +public protocol PostgresEncodable { /// identifies the data type that we will encode into `byteBuffer` in `encode` - var psqlType: PostgresDataType { get } + static var psqlType: PostgresDataType { get } /// identifies the postgres format that is used to encode the value into `byteBuffer` in `encode` - var psqlFormat: PostgresFormat { get } + static var psqlFormat: PostgresFormat { get } /// Encode the entity into the `byteBuffer` in Postgres binary format, without setting /// the byte count. This method is called from the ``PostgresBindings``. @@ -71,6 +71,7 @@ extension PostgresDecodable { protocol PostgresCodable: PostgresEncodable, PostgresDecodable {} extension PostgresEncodable { + @inlinable func encodeRaw( into buffer: inout ByteBuffer, context: PostgresEncodingContext diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index b1f00f0a..00687992 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -86,7 +86,7 @@ struct PostgresBindings: Hashable { } init(value: Value) { - self.init(dataType: value.psqlType, format: value.psqlFormat) + self.init(dataType: Value.psqlType, format: Value.psqlFormat) } } diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 723a8034..2e4de247 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -294,15 +294,17 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } + let uuidString = "2c68f645-9ca6-468b-b193-ee97f241c2f8" + var stream: PSQLRowStream? XCTAssertNoThrow(stream = try conn?.query(""" - SELECT '2c68f645-9ca6-468b-b193-ee97f241c2f8'::UUID as uuid + SELECT \(uuidString)::UUID as uuid """, logger: .psqlTest).wait()) var rows: [PostgresRow]? XCTAssertNoThrow(rows = try stream?.all().wait()) XCTAssertEqual(rows?.count, 1) - XCTAssertEqual(try rows?.first?.decode(UUID.self, context: .default), UUID(uuidString: "2c68f645-9ca6-468b-b193-ee97f241c2f8")) + XCTAssertEqual(try rows?.first?.decode(UUID.self, context: .default), UUID(uuidString: uuidString)) } func testRoundTripJSONB() { diff --git a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift index a7c40550..3798dab0 100644 --- a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift @@ -7,54 +7,54 @@ class Array_PSQLCodableTests: XCTestCase { func testArrayTypes() { XCTAssertEqual(Bool.psqlArrayType, .boolArray) - XCTAssertEqual(Bool.psqlArrayElementType, .bool) - XCTAssertEqual([Bool]().psqlType, .boolArray) + XCTAssertEqual(Bool.psqlType, .bool) + XCTAssertEqual([Bool].psqlType, .boolArray) XCTAssertEqual(ByteBuffer.psqlArrayType, .byteaArray) - XCTAssertEqual(ByteBuffer.psqlArrayElementType, .bytea) - XCTAssertEqual([ByteBuffer]().psqlType, .byteaArray) + XCTAssertEqual(ByteBuffer.psqlType, .bytea) + XCTAssertEqual([ByteBuffer].psqlType, .byteaArray) XCTAssertEqual(UInt8.psqlArrayType, .charArray) - XCTAssertEqual(UInt8.psqlArrayElementType, .char) - XCTAssertEqual([UInt8]().psqlType, .charArray) + XCTAssertEqual(UInt8.psqlType, .char) + XCTAssertEqual([UInt8].psqlType, .charArray) XCTAssertEqual(Int16.psqlArrayType, .int2Array) - XCTAssertEqual(Int16.psqlArrayElementType, .int2) - XCTAssertEqual([Int16]().psqlType, .int2Array) + XCTAssertEqual(Int16.psqlType, .int2) + XCTAssertEqual([Int16].psqlType, .int2Array) XCTAssertEqual(Int32.psqlArrayType, .int4Array) - XCTAssertEqual(Int32.psqlArrayElementType, .int4) - XCTAssertEqual([Int32]().psqlType, .int4Array) + XCTAssertEqual(Int32.psqlType, .int4) + XCTAssertEqual([Int32].psqlType, .int4Array) XCTAssertEqual(Int64.psqlArrayType, .int8Array) - XCTAssertEqual(Int64.psqlArrayElementType, .int8) - XCTAssertEqual([Int64]().psqlType, .int8Array) + XCTAssertEqual(Int64.psqlType, .int8) + XCTAssertEqual([Int64].psqlType, .int8Array) #if (arch(i386) || arch(arm)) XCTAssertEqual(Int.psqlArrayType, .int4Array) - XCTAssertEqual(Int.psqlArrayElementType, .int4) - XCTAssertEqual([Int]().psqlType, .int4Array) + XCTAssertEqual(Int.psqlType, .int4) + XCTAssertEqual([Int].psqlType, .int4Array) #else XCTAssertEqual(Int.psqlArrayType, .int8Array) - XCTAssertEqual(Int.psqlArrayElementType, .int8) - XCTAssertEqual([Int]().psqlType, .int8Array) + XCTAssertEqual(Int.psqlType, .int8) + XCTAssertEqual([Int].psqlType, .int8Array) #endif XCTAssertEqual(Float.psqlArrayType, .float4Array) - XCTAssertEqual(Float.psqlArrayElementType, .float4) - XCTAssertEqual([Float]().psqlType, .float4Array) + XCTAssertEqual(Float.psqlType, .float4) + XCTAssertEqual([Float].psqlType, .float4Array) XCTAssertEqual(Double.psqlArrayType, .float8Array) - XCTAssertEqual(Double.psqlArrayElementType, .float8) - XCTAssertEqual([Double]().psqlType, .float8Array) + XCTAssertEqual(Double.psqlType, .float8) + XCTAssertEqual([Double].psqlType, .float8Array) XCTAssertEqual(String.psqlArrayType, .textArray) - XCTAssertEqual(String.psqlArrayElementType, .text) - XCTAssertEqual([String]().psqlType, .textArray) + XCTAssertEqual(String.psqlType, .text) + XCTAssertEqual([String].psqlType, .textArray) XCTAssertEqual(UUID.psqlArrayType, .uuidArray) - XCTAssertEqual(UUID.psqlArrayElementType, .uuid) - XCTAssertEqual([UUID]().psqlType, .uuidArray) + XCTAssertEqual(UUID.psqlType, .uuid) + XCTAssertEqual([UUID].psqlType, .uuidArray) } func testStringArrayRoundTrip() { @@ -83,7 +83,7 @@ class Array_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int32(2)) // invalid value buffer.writeInteger(Int32(0)) - buffer.writeInteger(String.psqlArrayElementType.rawValue) + buffer.writeInteger(String.psqlType.rawValue) XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) @@ -94,7 +94,7 @@ class Array_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int32(0)) // is empty buffer.writeInteger(Int32(1)) // invalid value, must always be 0 - buffer.writeInteger(String.psqlArrayElementType.rawValue) + buffer.writeInteger(String.psqlType.rawValue) XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) @@ -115,7 +115,7 @@ class Array_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int32(1)) // invalid value buffer.writeInteger(Int32(0)) - buffer.writeInteger(String.psqlArrayElementType.rawValue) + buffer.writeInteger(String.psqlType.rawValue) buffer.writeInteger(Int32(-123)) // expected element count buffer.writeInteger(Int32(1)) // dimensions... must be one @@ -128,7 +128,7 @@ class Array_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int32(1)) // invalid value buffer.writeInteger(Int32(0)) - buffer.writeInteger(String.psqlArrayElementType.rawValue) + buffer.writeInteger(String.psqlType.rawValue) buffer.writeInteger(Int32(1)) // expected element count buffer.writeInteger(Int32(2)) // dimensions... must be one @@ -141,7 +141,7 @@ class Array_PSQLCodableTests: XCTestCase { var unexpectedEndInElementLengthBuffer = ByteBuffer() unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // invalid value unexpectedEndInElementLengthBuffer.writeInteger(Int32(0)) - unexpectedEndInElementLengthBuffer.writeInteger(String.psqlArrayElementType.rawValue) + unexpectedEndInElementLengthBuffer.writeInteger(String.psqlType.rawValue) unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // expected element count unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // dimensions unexpectedEndInElementLengthBuffer.writeInteger(Int16(1)) // length of element, must be Int32 @@ -153,7 +153,7 @@ class Array_PSQLCodableTests: XCTestCase { var unexpectedEndInElementBuffer = ByteBuffer() unexpectedEndInElementBuffer.writeInteger(Int32(1)) // invalid value unexpectedEndInElementBuffer.writeInteger(Int32(0)) - unexpectedEndInElementBuffer.writeInteger(String.psqlArrayElementType.rawValue) + unexpectedEndInElementBuffer.writeInteger(String.psqlType.rawValue) unexpectedEndInElementBuffer.writeInteger(Int32(1)) // expected element count unexpectedEndInElementBuffer.writeInteger(Int32(1)) // dimensions unexpectedEndInElementBuffer.writeInteger(Int32(12)) // length of element, must be Int32 diff --git a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift index 8f77bcea..9526fcd6 100644 --- a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift @@ -11,8 +11,8 @@ class Bool_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) - XCTAssertEqual(value.psqlType, .bool) - XCTAssertEqual(value.psqlFormat, .binary) + XCTAssertEqual(Bool.psqlType, .bool) + XCTAssertEqual(Bool.psqlFormat, .binary) XCTAssertEqual(buffer.readableBytes, 1) XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 1) @@ -26,8 +26,8 @@ class Bool_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) - XCTAssertEqual(value.psqlType, .bool) - XCTAssertEqual(value.psqlFormat, .binary) + XCTAssertEqual(Bool.psqlType, .bool) + XCTAssertEqual(Bool.psqlFormat, .binary) XCTAssertEqual(buffer.readableBytes, 1) XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 0) diff --git a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift index b67c0b5e..9230aee7 100644 --- a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift @@ -9,7 +9,7 @@ class Bytes_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() data.encode(into: &buffer, context: .default) - XCTAssertEqual(data.psqlType, .bytea) + XCTAssertEqual(ByteBuffer.psqlType, .bytea) var result: Data? result = Data(from: &buffer, type: .bytea, format: .binary, context: .default) @@ -21,7 +21,7 @@ class Bytes_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() bytes.encode(into: &buffer, context: .default) - XCTAssertEqual(bytes.psqlType, .bytea) + XCTAssertEqual(ByteBuffer.psqlType, .bytea) var result: ByteBuffer? result = ByteBuffer(from: &buffer, type: .bytea, format: .binary, context: .default) @@ -47,7 +47,7 @@ class Bytes_PSQLCodableTests: XCTestCase { let sequence = ByteSequence() var buffer = ByteBuffer() sequence.encode(into: &buffer, context: .default) - XCTAssertEqual(sequence.psqlType, .bytea) + XCTAssertEqual(ByteSequence.psqlType, .bytea) XCTAssertEqual(buffer.readableBytes, 256) } } diff --git a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift index 38ce1d04..9fe0e67b 100644 --- a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift @@ -9,7 +9,7 @@ class Date_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) - XCTAssertEqual(value.psqlType, .timestamptz) + XCTAssertEqual(Date.psqlType, .timestamptz) XCTAssertEqual(buffer.readableBytes, 8) var result: Date? diff --git a/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift index 2898f998..cfb7f7e3 100644 --- a/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift @@ -10,7 +10,7 @@ class Decimal_PSQLCodableTests: XCTestCase { for value in values { var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) - XCTAssertEqual(value.psqlType, .numeric) + XCTAssertEqual(Decimal.psqlType, .numeric) var result: Decimal? XCTAssertNoThrow(result = try Decimal(from: &buffer, type: .numeric, format: .binary, context: .default)) diff --git a/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift index 3cac7e6f..9fd1bb9e 100644 --- a/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift @@ -10,7 +10,7 @@ class Float_PSQLCodableTests: XCTestCase { for value in values { var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) - XCTAssertEqual(value.psqlType, .float8) + XCTAssertEqual(Double.psqlType, .float8) XCTAssertEqual(buffer.readableBytes, 8) var result: Double? @@ -25,7 +25,7 @@ class Float_PSQLCodableTests: XCTestCase { for value in values { var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) - XCTAssertEqual(value.psqlType, .float4) + XCTAssertEqual(Float.psqlType, .float4) XCTAssertEqual(buffer.readableBytes, 4) var result: Float? @@ -39,7 +39,7 @@ class Float_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) - XCTAssertEqual(value.psqlType, .float8) + XCTAssertEqual(Double.psqlType, .float8) XCTAssertEqual(buffer.readableBytes, 8) var result: Double? @@ -52,7 +52,7 @@ class Float_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) - XCTAssertEqual(value.psqlType, .float8) + XCTAssertEqual(Double.psqlType, .float8) XCTAssertEqual(buffer.readableBytes, 8) var result: Double? @@ -66,7 +66,7 @@ class Float_PSQLCodableTests: XCTestCase { for value in values { var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) - XCTAssertEqual(value.psqlType, .float4) + XCTAssertEqual(Float.psqlType, .float4) XCTAssertEqual(buffer.readableBytes, 4) var result: Double? @@ -81,7 +81,7 @@ class Float_PSQLCodableTests: XCTestCase { for value in values { var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) - XCTAssertEqual(value.psqlType, .float8) + XCTAssertEqual(Double.psqlType, .float8) XCTAssertEqual(buffer.readableBytes, 8) var result: Float? diff --git a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift index 46563973..dbaa43ee 100644 --- a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift @@ -16,7 +16,7 @@ class JSON_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() let hello = Hello(name: "world") XCTAssertNoThrow(try hello.encode(into: &buffer, context: .default)) - XCTAssertEqual(hello.psqlType, .jsonb) + XCTAssertEqual(Hello.psqlType, .jsonb) // verify jsonb prefix byte XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 1) diff --git a/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift index 99a250aa..a0808daf 100644 --- a/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift @@ -16,11 +16,11 @@ class RawRepresentable_PSQLCodableTests: XCTestCase { for value in values { var buffer = ByteBuffer() XCTAssertNoThrow(try value.encode(into: &buffer, context: .default)) - XCTAssertEqual(value.psqlType, Int16.psqlArrayElementType) + XCTAssertEqual(MyRawRepresentable.psqlType, Int16.psqlType) XCTAssertEqual(buffer.readableBytes, 2) var result: MyRawRepresentable? - XCTAssertNoThrow(result = try MyRawRepresentable(from: &buffer, type: Int16.psqlArrayElementType, format: .binary, context: .default)) + XCTAssertNoThrow(result = try MyRawRepresentable(from: &buffer, type: Int16.psqlType, format: .binary, context: .default)) XCTAssertEqual(value, result) } } @@ -29,7 +29,7 @@ class RawRepresentable_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int16(4)) // out of bounds - XCTAssertThrowsError(try MyRawRepresentable(from: &buffer, type: Int16.psqlArrayElementType, format: .binary, context: .default)) { + XCTAssertThrowsError(try MyRawRepresentable(from: &buffer, type: Int16.psqlType, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } @@ -38,7 +38,7 @@ class RawRepresentable_PSQLCodableTests: XCTestCase { var buffer = ByteBuffer() buffer.writeInteger(Int32(1)) // out of bounds - XCTAssertThrowsError(try MyRawRepresentable(from: &buffer, type: Int32.psqlArrayElementType, format: .binary, context: .default)) { + XCTAssertThrowsError(try MyRawRepresentable(from: &buffer, type: Int32.psqlType, format: .binary, context: .default)) { XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) } } diff --git a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift index 9d2937e4..42edbda5 100644 --- a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift @@ -10,7 +10,7 @@ class String_PSQLCodableTests: XCTestCase { value.encode(into: &buffer, context: .default) - XCTAssertEqual(value.psqlType, .text) + XCTAssertEqual(String.psqlType, .text) XCTAssertEqual(buffer.readString(length: buffer.readableBytes), value) } diff --git a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift index 1df8001b..0693f7f4 100644 --- a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift @@ -11,8 +11,8 @@ class UUID_PSQLCodableTests: XCTestCase { uuid.encode(into: &buffer, context: .default) - XCTAssertEqual(uuid.psqlType, .uuid) - XCTAssertEqual(uuid.psqlFormat, .binary) + XCTAssertEqual(UUID.psqlType, .uuid) + XCTAssertEqual(UUID.psqlFormat, .binary) XCTAssertEqual(buffer.readableBytes, 16) var byteIterator = buffer.readableBytesView.makeIterator() From d9ba5770d68be5f33b99bfa4e352eacb21428f84 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 11 Mar 2022 09:17:05 +0100 Subject: [PATCH 084/292] Make Postgres async query public (#249) --- .../Connection/PostgresConnection.swift | 17 ++++- Sources/PostgresNIO/New/PostgresQuery.swift | 68 +++++++++++-------- Tests/IntegrationTests/AsyncTests.swift | 2 +- 3 files changed, 57 insertions(+), 30 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index d9f24117..390edcce 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -442,7 +442,22 @@ extension PostgresConnection { try await self.close().get() } - func query(_ query: PostgresQuery, logger: Logger, file: String = #file, line: UInt = #line) async throws -> PostgresRowSequence { + /// Run a query on the Postgres server the connection is connected to. + /// + /// - Parameters: + /// - query: The ``PostgresQuery`` to run + /// - logger: The `Logger` to log into for the query + /// - file: The file, the query was started in. Used for better error reporting. + /// - line: The line, the query was started in. Used for better error reporting. + /// - Returns: A ``PostgresRowSequence`` containing the rows the server sent as the query result. + /// The sequence be discarded. + @discardableResult + public func query( + _ query: PostgresQuery, + logger: Logger, + file: String = #file, + line: Int = #line + ) async throws -> PostgresRowSequence { var logger = logger logger[postgresMetadataKey: .connectionID] = "\(self.id)" diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 00687992..5bb33988 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -1,51 +1,54 @@ -struct PostgresQuery: Hashable { +/// A Postgres SQL query, that can be executed on a Postgres server. Contains the raw sql string and bindings. +public struct PostgresQuery: Hashable { /// The query string - var sql: String + public var sql: String /// The query binds - var binds: PostgresBindings + public var binds: PostgresBindings - init(unsafeSQL sql: String, binds: PostgresBindings = PostgresBindings()) { + public init(unsafeSQL sql: String, binds: PostgresBindings = PostgresBindings()) { self.sql = sql self.binds = binds } } extension PostgresQuery: ExpressibleByStringInterpolation { - typealias StringInterpolation = Interpolation - - init(stringInterpolation: Interpolation) { + public init(stringInterpolation: StringInterpolation) { self.sql = stringInterpolation.sql self.binds = stringInterpolation.binds } - init(stringLiteral value: String) { + public init(stringLiteral value: String) { self.sql = value self.binds = PostgresBindings() } } extension PostgresQuery { - struct Interpolation: StringInterpolationProtocol { - typealias StringLiteralType = String + public struct StringInterpolation: StringInterpolationProtocol { + public typealias StringLiteralType = String + @usableFromInline var sql: String + @usableFromInline var binds: PostgresBindings - init(literalCapacity: Int, interpolationCount: Int) { + public init(literalCapacity: Int, interpolationCount: Int) { self.sql = "" self.binds = PostgresBindings(capacity: interpolationCount) } - mutating func appendLiteral(_ literal: String) { + public mutating func appendLiteral(_ literal: String) { self.sql.append(contentsOf: literal) } - mutating func appendInterpolation(_ value: Value) throws { + @inlinable + public mutating func appendInterpolation(_ value: Value) throws { try self.binds.append(value, context: .default) self.sql.append(contentsOf: "$\(self.binds.count)") } - mutating func appendInterpolation(_ value: Optional) throws { + @inlinable + public mutating func appendInterpolation(_ value: Optional) throws { switch value { case .none: self.binds.appendNull() @@ -56,7 +59,8 @@ extension PostgresQuery { self.sql.append(contentsOf: "$\(self.binds.count)") } - mutating func appendInterpolation( + @inlinable + public mutating func appendInterpolation( _ value: Value, context: PostgresEncodingContext ) throws { @@ -75,45 +79,61 @@ struct PSQLExecuteStatement { var rowDescription: RowDescription? } -struct PostgresBindings: Hashable { +public struct PostgresBindings: Hashable { + @usableFromInline struct Metadata: Hashable { + @usableFromInline var dataType: PostgresDataType + @usableFromInline var format: PostgresFormat + @inlinable init(dataType: PostgresDataType, format: PostgresFormat) { self.dataType = dataType self.format = format } + @inlinable init(value: Value) { self.init(dataType: Value.psqlType, format: Value.psqlFormat) } } + @usableFromInline var metadata: [Metadata] + @usableFromInline var bytes: ByteBuffer - var count: Int { + public var count: Int { self.metadata.count } - init() { + public init() { self.metadata = [] self.bytes = ByteBuffer() } - init(capacity: Int) { + public init(capacity: Int) { self.metadata = [] self.metadata.reserveCapacity(capacity) self.bytes = ByteBuffer() self.bytes.reserveCapacity(128 * capacity) } - mutating func appendNull() { + public mutating func appendNull() { self.bytes.writeInteger(-1, as: Int32.self) self.metadata.append(.init(dataType: .null, format: .binary)) } + @inlinable + public mutating func append( + _ value: Value, + context: PostgresEncodingContext + ) throws { + try value.encodeRaw(into: &self.bytes, context: context) + self.metadata.append(.init(value: value)) + } + mutating func append(_ postgresData: PostgresData) { switch postgresData.value { case .none: @@ -124,12 +144,4 @@ struct PostgresBindings: Hashable { } self.metadata.append(.init(dataType: postgresData.type, format: .binary)) } - - mutating func append( - _ value: Value, - context: PostgresEncodingContext - ) throws { - try value.encodeRaw(into: &self.bytes, context: context) - self.metadata.append(.init(value: value)) - } } diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 691c334f..afb9f590 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -1,6 +1,6 @@ import Logging import XCTest -@testable import PostgresNIO +import PostgresNIO #if swift(>=5.5.2) final class AsyncPostgresConnectionTests: XCTestCase { From c7edb9b71e5055a0c5d918b91a44a88498b997c4 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 11 Mar 2022 09:30:29 +0100 Subject: [PATCH 085/292] Rename PostgresBackendMessageDecoder (#252) --- Sources/PostgresNIO/New/PSQLChannelHandler.swift | 6 +++--- ...r.swift => PostgresBackendMessageDecoder.swift} | 2 +- .../New/Messages/AuthenticationTests.swift | 2 +- .../New/Messages/BackendKeyDataTests.swift | 4 ++-- .../New/Messages/DataRowTests.swift | 2 +- .../New/Messages/ErrorResponseTests.swift | 2 +- .../New/Messages/NotificationResponseTests.swift | 6 +++--- .../New/Messages/ParameterDescriptionTests.swift | 6 +++--- .../New/Messages/ParameterStatusTests.swift | 6 +++--- .../New/Messages/ReadyForQueryTests.swift | 6 +++--- .../New/Messages/RowDescriptionTests.swift | 10 +++++----- .../New/PSQLBackendMessageTests.swift | 14 +++++++------- 12 files changed, 33 insertions(+), 33 deletions(-) rename Sources/PostgresNIO/New/{PSQLBackendMessageDecoder.swift => PostgresBackendMessageDecoder.swift} (99%) diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PSQLChannelHandler.swift index d6dcd253..ff2bdc44 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PSQLChannelHandler.swift @@ -24,7 +24,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { /// The context is captured in `handlerAdded` and released` in `handlerRemoved` private var handlerContext: ChannelHandlerContext! private var rowStream: PSQLRowStream? - private var decoder: NIOSingleStepByteToMessageProcessor + private var decoder: NIOSingleStepByteToMessageProcessor private var encoder: BufferedMessageEncoder! private let configuration: PostgresConnection.InternalConfiguration private let configureSSLCallback: ((Channel) throws -> Void)? @@ -40,7 +40,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger - self.decoder = NIOSingleStepByteToMessageProcessor(PSQLBackendMessageDecoder()) + self.decoder = NIOSingleStepByteToMessageProcessor(PostgresBackendMessageDecoder()) } #if DEBUG @@ -54,7 +54,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger - self.decoder = NIOSingleStepByteToMessageProcessor(PSQLBackendMessageDecoder()) + self.decoder = NIOSingleStepByteToMessageProcessor(PostgresBackendMessageDecoder()) } #endif diff --git a/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift b/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift similarity index 99% rename from Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift rename to Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift index 9a3d6628..e8487fb6 100644 --- a/Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift +++ b/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift @@ -1,4 +1,4 @@ -struct PSQLBackendMessageDecoder: NIOSingleStepByteToMessageDecoder { +struct PostgresBackendMessageDecoder: NIOSingleStepByteToMessageDecoder { typealias InboundOut = PostgresBackendMessage private(set) var hasAlreadyReceivedBytes: Bool diff --git a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift index 85a4314f..31a21a91 100644 --- a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift @@ -40,6 +40,6 @@ class AuthenticationTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, expected)], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) } } diff --git a/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift b/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift index 2db8493b..b67145c2 100644 --- a/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift @@ -16,7 +16,7 @@ class BackendKeyDataTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: expectedInOuts, - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) } func testDecodeInvalidLength() { @@ -32,7 +32,7 @@ class BackendKeyDataTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: expected, - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { XCTAssert($0 is PSQLDecodingError) } } diff --git a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift index 660baa92..db31b98a 100644 --- a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift @@ -28,7 +28,7 @@ class DataRowTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: expectedInOuts, - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) } func testIteratingElements() { diff --git a/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift index 038ec34c..80015ea0 100644 --- a/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ErrorResponseTests.swift @@ -30,6 +30,6 @@ class ErrorResponseTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: expectedInOuts, - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) } } diff --git a/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift index f41a74af..7928e3f8 100644 --- a/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift @@ -27,7 +27,7 @@ class NotificationResponseTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, expected)], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) } func testDecodeFailureBecauseOfMissingNullTermination() { @@ -40,7 +40,7 @@ class NotificationResponseTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { XCTAssert($0 is PSQLDecodingError) } } @@ -55,7 +55,7 @@ class NotificationResponseTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { XCTAssert($0 is PSQLDecodingError) } } diff --git a/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift b/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift index 5c3ff150..dd42aea4 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift @@ -27,7 +27,7 @@ class ParameterDescriptionTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, expected)], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) } func testDecodeWithNegativeCount() { @@ -43,7 +43,7 @@ class ParameterDescriptionTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { XCTAssert($0 is PSQLDecodingError) } } @@ -62,7 +62,7 @@ class ParameterDescriptionTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { XCTAssert($0 is PSQLDecodingError) } } diff --git a/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift b/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift index a84e2ac4..ca4aa942 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift @@ -42,7 +42,7 @@ class ParameterStatusTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, expected)], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) } func testDecodeFailureBecauseOfMissingNullTermination() { @@ -54,7 +54,7 @@ class ParameterStatusTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { XCTAssert($0 is PSQLDecodingError) } } @@ -68,7 +68,7 @@ class ParameterStatusTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { XCTAssert($0 is PSQLDecodingError) } } diff --git a/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift b/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift index 8ece1bfc..e915be72 100644 --- a/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift @@ -33,7 +33,7 @@ class ReadyForQueryTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, expected)], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) } @@ -47,7 +47,7 @@ class ReadyForQueryTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { XCTAssert($0 is PSQLDecodingError) } } @@ -61,7 +61,7 @@ class ReadyForQueryTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { XCTAssert($0 is PSQLDecodingError) } } diff --git a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift index 7e941d54..899c88f1 100644 --- a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift @@ -38,7 +38,7 @@ class RowDescriptionTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, expected)], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) } func testDecodeFailureBecauseOfMissingNullTerminationInColumnName() { @@ -59,7 +59,7 @@ class RowDescriptionTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { XCTAssert($0 is PSQLDecodingError) } } @@ -81,7 +81,7 @@ class RowDescriptionTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { XCTAssert($0 is PSQLDecodingError) } } @@ -104,7 +104,7 @@ class RowDescriptionTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { XCTAssert($0 is PSQLDecodingError) } } @@ -127,7 +127,7 @@ class RowDescriptionTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { XCTAssert($0 is PSQLDecodingError) } } diff --git a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift index 60209d2b..d55e86bc 100644 --- a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift @@ -97,7 +97,7 @@ class PSQLBackendMessageTests: XCTestCase { expectedMessages.append(.parameterStatus(parameterStatus)) } - let handler = ByteToMessageHandler(PSQLBackendMessageDecoder()) + let handler = ByteToMessageHandler(PostgresBackendMessageDecoder()) let embedded = EmbeddedChannel(handler: handler) XCTAssertNoThrow(try embedded.writeInbound(buffer)) @@ -137,7 +137,7 @@ class PSQLBackendMessageTests: XCTestCase { buffer.writeInteger(0, as: UInt8.self) // signal done } - let handler = ByteToMessageHandler(PSQLBackendMessageDecoder()) + let handler = ByteToMessageHandler(PostgresBackendMessageDecoder()) let embedded = EmbeddedChannel(handler: handler) XCTAssertNoThrow(try embedded.writeInbound(buffer)) @@ -174,7 +174,7 @@ class PSQLBackendMessageTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, expected)], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) } func testPayloadsWithoutAssociatedValuesInvalidLength() { @@ -195,7 +195,7 @@ class PSQLBackendMessageTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { XCTAssert($0 is PSQLDecodingError) } } @@ -222,7 +222,7 @@ class PSQLBackendMessageTests: XCTestCase { XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(okBuffer, expected)], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) // test commandTag is not null terminated for message in expected { @@ -237,7 +237,7 @@ class PSQLBackendMessageTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(failBuffer, [])], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { XCTAssert($0 is PSQLDecodingError) } } @@ -250,7 +250,7 @@ class PSQLBackendMessageTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], - decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { XCTAssert($0 is PSQLDecodingError) } } From 1cd8d366cf17f9b64d5fc520f57043c48965eeae Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 11 Mar 2022 09:47:35 +0100 Subject: [PATCH 086/292] Add support for Network.framework (#253) --- Package.swift | 2 ++ .../Connection/PostgresConnection.swift | 25 +++++++++++++++++-- Tests/IntegrationTests/AsyncTests.swift | 25 +++++++++++++++++++ 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/Package.swift b/Package.swift index 510c04fe..2dacd63f 100644 --- a/Package.swift +++ b/Package.swift @@ -14,6 +14,7 @@ let package = Package( ], dependencies: [ .package(url: "https://github.com/apple/swift-nio.git", from: "2.35.0"), + .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.11.4"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.1"), .package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "3.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.0.0"), @@ -27,6 +28,7 @@ let package = Package( .product(name: "NIO", package: "swift-nio"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOPosix", package: "swift-nio"), + .product(name: "NIOTransportServices", package: "swift-nio-transport-services"), .product(name: "NIOTLS", package: "swift-nio"), .product(name: "NIOSSL", package: "swift-nio-ssl"), .product(name: "NIOFoundationCompat", package: "swift-nio"), diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 390edcce..4575cd28 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -1,5 +1,8 @@ import NIOCore import NIOConcurrencyHelpers +#if canImport(Network) +import NIOTransportServices +#endif import NIOSSL import Logging import NIOPosix @@ -249,12 +252,13 @@ public final class PostgresConnection { // thread and the EventLoop. return eventLoop.flatSubmit { () -> EventLoopFuture in let connectFuture: EventLoopFuture + let bootstrap = self.makeBootstrap(on: eventLoop, configuration: configuration) switch configuration.connection { case .resolved(let address, _): - connectFuture = ClientBootstrap(group: eventLoop).connect(to: address) + connectFuture = bootstrap.connect(to: address) case .unresolved(let host, let port): - connectFuture = ClientBootstrap(group: eventLoop).connect(host: host, port: port) + connectFuture = bootstrap.connect(host: host, port: port) } return connectFuture.flatMap { channel -> EventLoopFuture in @@ -271,6 +275,23 @@ public final class PostgresConnection { } } + static func makeBootstrap( + on eventLoop: EventLoop, + configuration: PostgresConnection.InternalConfiguration + ) -> NIOClientTCPBootstrapProtocol { + #if canImport(Network) + if let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) { + return tsBootstrap + } + #endif + + if let nioBootstrap = ClientBootstrap(validatingGroup: eventLoop) { + return nioBootstrap + } + + fatalError("No matching bootstrap found") + } + // MARK: Query func query(_ query: PostgresQuery, logger: Logger) -> EventLoopFuture { diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index afb9f590..d28a9e62 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -1,6 +1,9 @@ import Logging import XCTest import PostgresNIO +#if canImport(Network) +import NIOTransportServices +#endif #if swift(>=5.5.2) final class AsyncPostgresConnectionTests: XCTestCase { @@ -41,6 +44,28 @@ final class AsyncPostgresConnectionTests: XCTestCase { XCTAssertEqual(counter, end + 1) } } + + #if canImport(Network) + func testSelect10kRowsNetworkFramework() async throws { + let eventLoopGroup = NIOTSEventLoopGroup() + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let start = 1 + let end = 10000 + + try await withTestConnection(on: eventLoop) { connection in + let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest) + var counter = 1 + for try await element in rows.decode(Int.self, context: .default) { + XCTAssertEqual(element, counter) + counter += 1 + } + + XCTAssertEqual(counter, end + 1) + } + } + #endif } extension XCTestCase { From 8e341c1e546b95acadf306eadaf0230a096988a8 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 11 Mar 2022 10:12:38 +0100 Subject: [PATCH 087/292] Make DataRow Sendable (#250) --- Sources/PostgresNIO/New/Messages/DataRow.swift | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/Sources/PostgresNIO/New/Messages/DataRow.swift b/Sources/PostgresNIO/New/Messages/DataRow.swift index b5c3f8e7..0deb0043 100644 --- a/Sources/PostgresNIO/New/Messages/DataRow.swift +++ b/Sources/PostgresNIO/New/Messages/DataRow.swift @@ -1,4 +1,8 @@ +#if swift(>=5.6) +@preconcurrency import NIOCore +#else import NIOCore +#endif /// A backend data row message. /// @@ -116,3 +120,7 @@ extension DataRow { return self[byteIndex] } } + +#if swift(>=5.6) +extension DataRow: Sendable {} +#endif From 7f290f2703dd1e6cd0d777fbfe3f0ef76511d045 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 11 Mar 2022 10:52:26 +0100 Subject: [PATCH 088/292] Rename PostgresChannelHandler (#251) --- .../PostgresNIO/Connection/PostgresConnection.swift | 4 ++-- Sources/PostgresNIO/New/PSQLEventsHandler.swift | 6 +++--- ...nelHandler.swift => PostgresChannelHandler.swift} | 6 +++--- ...Tests.swift => PostgresChannelHandlerTests.swift} | 12 ++++++------ 4 files changed, 14 insertions(+), 14 deletions(-) rename Sources/PostgresNIO/New/{PSQLChannelHandler.swift => PostgresChannelHandler.swift} (99%) rename Tests/PostgresNIOTests/New/{PSQLChannelHandlerTests.swift => PostgresChannelHandlerTests.swift} (94%) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 4575cd28..ad3d14e7 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -173,7 +173,7 @@ public final class PostgresConnection { } } - let channelHandler = PSQLChannelHandler( + let channelHandler = PostgresChannelHandler( configuration: configuration, logger: logger, configureSSLCallback: configureSSLCallback @@ -597,7 +597,7 @@ extension PostgresConnection { let listenContext = PostgresListenContext() - self.channel.pipeline.handler(type: PSQLChannelHandler.self).whenSuccess { handler in + self.channel.pipeline.handler(type: PostgresChannelHandler.self).whenSuccess { handler in if self.notificationListeners[channel] != nil { self.notificationListeners[channel]!.append((listenContext, notificationHandler)) } diff --git a/Sources/PostgresNIO/New/PSQLEventsHandler.swift b/Sources/PostgresNIO/New/PSQLEventsHandler.swift index 0318061e..3233fb77 100644 --- a/Sources/PostgresNIO/New/PSQLEventsHandler.swift +++ b/Sources/PostgresNIO/New/PSQLEventsHandler.swift @@ -3,7 +3,7 @@ import NIOTLS import Logging enum PSQLOutgoingEvent { - /// the event we send down the channel to inform the `PSQLChannelHandler` to authenticate + /// the event we send down the channel to inform the ``PostgresChannelHandler`` to authenticate /// /// this shall be removed with the next breaking change and always supplied with `PSQLConnection.Configuration` case authenticate(AuthContext) @@ -11,10 +11,10 @@ enum PSQLOutgoingEvent { enum PSQLEvent { - /// the event that is used to inform upstream handlers that `PSQLChannelHandler` has established a connection + /// the event that is used to inform upstream handlers that ``PostgresChannelHandler`` has established a connection case readyForStartup - /// the event that is used to inform upstream handlers that `PSQLChannelHandler` is currently idle + /// the event that is used to inform upstream handlers that ``PostgresChannelHandler`` is currently idle case readyForQuery } diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift similarity index 99% rename from Sources/PostgresNIO/New/PSQLChannelHandler.swift rename to Sources/PostgresNIO/New/PostgresChannelHandler.swift index ff2bdc44..55d7aff1 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -7,7 +7,7 @@ protocol PSQLChannelHandlerNotificationDelegate: AnyObject { func notificationReceived(_: PostgresBackendMessage.NotificationResponse) } -final class PSQLChannelHandler: ChannelDuplexHandler { +final class PostgresChannelHandler: ChannelDuplexHandler { typealias OutboundIn = PSQLTask typealias InboundIn = ByteBuffer typealias OutboundOut = ByteBuffer @@ -501,7 +501,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { } } -extension PSQLChannelHandler: PSQLRowsDataSource { +extension PostgresChannelHandler: PSQLRowsDataSource { func request(for stream: PSQLRowStream) { guard self.rowStream === stream else { return @@ -587,7 +587,7 @@ extension ConnectionStateMachine.TLSConfiguration { } } -extension PSQLChannelHandler { +extension PostgresChannelHandler { convenience init( configuration: PostgresConnection.InternalConfiguration, configureSSLCallback: ((Channel) throws -> Void)?) diff --git a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift similarity index 94% rename from Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift rename to Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index 36eac812..d3c2b10f 100644 --- a/Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -5,13 +5,13 @@ import NIOSSL import NIOEmbedded @testable import PostgresNIO -class PSQLChannelHandlerTests: XCTestCase { +class PostgresChannelHandlerTests: XCTestCase { // MARK: Startup func testHandlerAddedWithoutSSL() { let config = self.testConnectionConfiguration() - let handler = PSQLChannelHandler(configuration: config, configureSSLCallback: nil) + let handler = PostgresChannelHandler(configuration: config, configureSSLCallback: nil) let embedded = EmbeddedChannel(handlers: [ ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), @@ -40,7 +40,7 @@ class PSQLChannelHandlerTests: XCTestCase { var config = self.testConnectionConfiguration() XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) var addSSLCallbackIsHit = false - let handler = PSQLChannelHandler(configuration: config) { channel in + let handler = PostgresChannelHandler(configuration: config) { channel in addSSLCallbackIsHit = true } let embedded = EmbeddedChannel(handlers: [ @@ -82,7 +82,7 @@ class PSQLChannelHandlerTests: XCTestCase { var config = self.testConnectionConfiguration() XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) - let handler = PSQLChannelHandler(configuration: config) { channel in + let handler = PostgresChannelHandler(configuration: config) { channel in XCTFail("This callback should never be exectuded") throw PSQLError.sslUnsupported } @@ -118,7 +118,7 @@ class PSQLChannelHandlerTests: XCTestCase { database: config.authentication?.database ) let state = ConnectionStateMachine(.waitingToStartAuthentication) - let handler = PSQLChannelHandler(configuration: config, state: state, configureSSLCallback: nil) + let handler = PostgresChannelHandler(configuration: config, state: state, configureSSLCallback: nil) let embedded = EmbeddedChannel(handlers: [ ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), @@ -147,7 +147,7 @@ class PSQLChannelHandlerTests: XCTestCase { database: config.authentication?.database ) let state = ConnectionStateMachine(.waitingToStartAuthentication) - let handler = PSQLChannelHandler(configuration: config, state: state, configureSSLCallback: nil) + let handler = PostgresChannelHandler(configuration: config, state: state, configureSSLCallback: nil) let embedded = EmbeddedChannel(handlers: [ ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), From 93a928d984fe247b52579e8175c824a519598b08 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 11 Mar 2022 11:58:48 +0100 Subject: [PATCH 089/292] README update for async/await (#254) Co-authored-by: Gwynne Raskind --- README.md | 289 +++++++++++++++++++++++++----------------------------- 1 file changed, 133 insertions(+), 156 deletions(-) diff --git a/README.md b/README.md index 79dfc669..85c763cd 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,9 @@ PostgresNIO -
- + + + SSWG Incubating + + Documentation @@ -18,211 +21,185 @@

-🐘 Non-blocking, event-driven Swift client for PostgreSQL built on [SwiftNIO](https://github.com/apple/swift-nio). - -### Major Releases - -The table below shows a list of PostgresNIO major releases alongside their compatible NIO and Swift versions. - -|Version|NIO|Swift|SPM| -|-|-|-|-| -|1.0|2.0+|5.2+|`from: "1.0.0"`| - -Use the SPM string to easily include the dependendency in your `Package.swift` file. - -```swift -.package(url: "https://github.com/vapor/postgres-nio.git", from: ...) -``` - -### Supported Platforms - -PostgresNIO supports the following platforms: - -- Ubuntu 16.04+ -- macOS 10.15+ - -### Security - -Please see [SECURITY.md](https://github.com/vapor/.github/blob/main/SECURITY.md) for details on the security process. - -## Overview - -PostgresNIO is a client package for connecting to, authorizing, and querying a PostgreSQL server. At the heart of this module are NIO channel handlers for parsing and serializing messages in PostgreSQL's proprietary wire protocol. These channel handlers are combined in a request / response style connection type that provides a convenient, client-like interface for performing queries. - -Support for both simple (text) and parameterized (binary) querying is provided out of the box alongside a `PostgresData` type that handles conversion between PostgreSQL's wire format and native Swift types. - -### Motivation - -Most Swift implementations of Postgres clients are based on the [libpq](https://www.postgresql.org/docs/11/libpq.html) C library which handles transport internally. Building a library directly on top of Postgres' wire protocol using SwiftNIO should yield a more reliable, maintainable, and performant interface for PostgreSQL databases. - -### Goals - -This package is meant to be a low-level, unopinionated PostgreSQL wire-protocol implementation for Swift. The hope is that higher level packages can share PostgresNIO as a foundation for interacting with PostgreSQL servers without needing to duplicate complex logic. - -Because of this, PostgresNIO excludes some important concepts for the sake of simplicity, such as: - -- Connection pooling -- Swift `Codable` integration -- Query building - -If you are looking for a PostgreSQL client package to use in your project, take a look at these higher-level packages built on top of PostgresNIO: - -- [`vapor/postgres-kit`](https://github.com/vapor/postgresql) - -### Dependencies +🐘 Non-blocking, event-driven Swift client for PostgreSQL built on [SwiftNIO]. -This package has four dependencies: +Features: -- [`apple/swift-nio`](https://github.com/apple/swift-nio) for IO -- [`apple/swift-nio-ssl`](https://github.com/apple/swift-nio-ssl) for TLS -- [`apple/swift-log`](https://github.com/apple/swift-log) for logging -- [`apple/swift-metrics`](https://github.com/apple/swift-metrics) for metrics +- A `PostgresConnection` which allows you to connect to, authorize with, query, and retrieve results from a PostgreSQL server +- An async/await interface that supports backpressure +- Automatic conversions between Swift primitive types and the Postgres wire format +- Integrated with the Swift server ecosystem, including use of [SwiftLog]. +- Designed to run efficiently on all supported platforms (tested extensively on Linux and Darwin systems) +- Support for `Network.framework` when available (e.g. on Apple platforms) -This package has no additional system dependencies. +PostgresNIO does not have a `ConnectionPool` as of today, but this is a feature high on our list. If +you need a `ConnectionPool` today, please have a look at Vapor's [PostgresKit]. ## API Docs -Check out the [PostgresNIO API docs](https://api.vapor.codes/postgres-nio/main/PostgresNIO/) for a detailed look at all of the classes, structs, protocols, and more. +Check out the [PostgresNIO API docs](https://api.vapor.codes/postgres-nio/main/PostgresNIO/) for a +detailed look at all of the classes, structs, protocols, and more. -## Getting Started +## Getting started -This section will provide a quick look at using PostgresNIO. +#### Adding the dependency -### Creating a Connection - -The first step to making a query is creating a new `PostgresConnection`. The minimum requirements to create one are a `SocketAddress` and `EventLoop`. +Add `PostgresNIO` as dependency to your `Package.swift`: ```swift -import PostgresNIO - -let eventLoop: EventLoop = ... -let conn = try PostgresConnection.connect( - to: .makeAddressResolvingHost("my.psql.server", port: 5432), - on: eventLoop -).wait() -defer { try! conn.close().wait() } + dependencies: [ + .package(url: "https://github.com/vapor/postgres-nio.git", from: "1.8.0"), + ... + ] ``` -Note: These examples will make use of `wait()` for simplicity. This is appropriate if you are using PostgresNIO on the main thread, like for a CLI tool or in tests. However, you should never use `wait()` on an event loop. - -There are a few ways to create a `SocketAddress`: - -- `init(ipAddress: String, port: Int)` -- `init(unixDomainSocketPath: String)` -- `makeAddressResolvingHost(_ host: String, port: Int)` - -There are also some additional arguments you can supply to `connect`. - -- `tlsConfiguration` An optional `TLSConfiguration` struct. If supplied, the PostgreSQL connection will be upgraded to use SSL. -- `serverHostname` An optional `String` to use in conjunction with `tlsConfiguration` to specify the server's hostname. - -`connect` will return a future `PostgresConnection`, or an error if it could not connect. Make sure you close the connection before it deinitializes. - -### Authentication - -Once you have a connection, you will need to authenticate with the server using the `authenticate` method. - +Add `PostgresNIO` to the target you want to use it in: ```swift -try conn.authenticate( - username: "your_username", - database: "your_database", - password: "your_password" -).wait() + targets: [ + .target(name: "MyFancyTarget", dependencies: [ + .product(name: "PostgresNIO", package: "postgres-nio"), + ]) + ] ``` -This requires a username. You may supply a database name and password if needed. +#### Creating a connection -### Database Protocol - -Interaction with a server revolves around the `PostgresDatabase` protocol. This protocol includes methods like `query(_:)` for executing SQL queries and reading the resulting rows. - -`PostgresConnection` is the default implementation of `PostgresDatabase` provided by this package. Assume `db` here is the connection from the previous example. +To create a connection, first create a connection configuration object: ```swift import PostgresNIO -let db: PostgresDatabase = ... -// now we can use client to do queries +let config = PostgresConnection.Configuration( + connection: .init( + host: "localhost", + port: 5432 + ), + authentication: .init( + username: "my_username", + database: "my_database", + password: "my_password" + ), + tls: .disable +) ``` -### Simple Query +A connection must be created on a SwiftNIO `EventLoop`. In most server use cases, an +`EventLoopGroup` is created at app startup and closed during app shutdown. + +```swift +import NIOPosix -Simple (or text) queries allow you to execute a SQL string on the connected PostgreSQL server. These queries do not support binding parameters, so any values sent must be escaped manually. +let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) -These queries are most useful for schema or transactional queries, or simple selects. Note that values returned by simple queries will be transferred in the less efficient text format. +// Much later +try eventLoopGroup.syncShutdown() +``` -`simpleQuery` has two overloads, one that returns an array of rows, and one that accepts a closure for handling each row as it is returned. +A `Logger` is also required. ```swift -let rows = try db.simpleQuery("SELECT version()").wait() -print(rows) // [["version": "12.x.x"]] +import Logging -try db.simpleQuery("SELECT version()") { row in - print(row) // ["version": "12.x.x"] -}.wait() +let logger = Logger(label: "postgres-logger") ``` -### Parameterized Query +Now we can put it together: -Parameterized (or binary) queries allow you to execute a SQL string on the connected PostgreSQL server. These queries support passing bound parameters as a separate argument. Each parameter is represented in the SQL string using incrementing placeholders, starting at `$1`. +```swift +import PostgresNIO +import NIOPosix +import Logging + +let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) +let logger = Logger(label: "postgres-logger") + +let config = PostgresConnection.Configuration( + connection: .init( + host: "localhost", + port: 5432 + ), + authentication: .init( + username: "my_username", + database: "my_database", + password: "my_password" + ), + tls: .disable +) + +let connection = try await PostgresConnection.connect( + on eventLoop: eventLoopGroup.next(), + configuration: config, + id connectionID: 1, + logger: logger +) + +// Close your connection once done +try await connection.close() + +// Shutdown the EventLoopGroup, once all connections are closed. +try eventLoopGroup.syncShutdown() +``` -These queries are most useful for selecting, inserting, and updating data. Data for these queries is transferred using the highly efficient binary format. +#### Querying -Just like `simpleQuery`, `query` also offers two overloads. One that returns an array of rows, and one that accepts a closure for handling each row as it is returned. +Once a connection is established, queries can be sent to the server. This is very straightforward: ```swift -let rows = try db.query("SELECT * FROM planets WHERE name = $1", ["Earth"]).wait() -print(rows) // [["id": 42, "name": "Earth"]] - -try db.query("SELECT * FROM planets WHERE name = $1", ["Earth"]) { row in - print(row) // ["id": 42, "name": "Earth"] -}.wait() +let rows = try await connection.query("SELECT id, username, birthday FROM users", logger: logger) ``` -### Rows and Data - -Both `simpleQuery` and `query` return the same `PostgresRow` type. Columns can be fetched from the row using the `column(_: String)` method. +The query will return a [`PostgresRowSequence`], which is an AsyncSequence of [`PostgresRow`]s. The rows can be iterated one-by-one: ```swift -let row: PostgresRow = ... -let version = row.column("version") -print(version) // PostgresData? +for try await row in rows { + // do something with the row +} ``` -`PostgresRow` columns are stored as `PostgresData`. This struct contains the raw bytes returned by PostgreSQL as well as some information for parsing them, such as: +#### Decoding from PostgresRow -- Postgres column type -- Wire format: binary or text -- Value as array of bytes - -`PostgresData` has a variety of convenience methods for converting column data to usable Swift types. +However, in most cases it is much easier to request a row's fields as a set of Swift types: ```swift -let data: PostgresData= ... +for try await (id, username, birthday) in rows.decode((Int, String, Date).self, context: .default) { + // do something with the datatypes. +} +``` -print(data.string) // String? +A type must implement the `PostgresDecodable` protocol in order to be decoded from a row. PostgresNIO provides default implementations for most of Swift's builtin types, as well as some types provided by Foundation: -// Postgres only supports signed Ints. -print(data.int) // Int? -print(data.int16) // Int16? -print(data.int32) // Int32? -print(data.int64) // Int64? +- `Bool` +- `Bytes`, `Data`, `ByteBuffer` +- `Date` +- `UInt8`, `Int16`, `Int32`, `Int64`, `Int` +- `Float`, `Double` +- `String` +- `UUID` -// 'char' can be interpreted as a UInt8. -// It will show in db as a character though. -print(data.uint8) // UInt8? +#### Querying with parameters -print(data.bool) // Bool? +Sending parameterized queries to the database is also supported (in the coolest way possible): -print(try data.jsonb(as: Foo.self)) // Foo? +```swift +let id = 1 +let username = "fancyuser" +let birthday = Date() +try await connection.query(""" + INSERT INTO users (id, username, birthday) VALUES (\(id), \(username), \(birthday)) + """, + logger: logger +) +``` -print(data.float) // Float? -print(data.double) // Double? +While this looks at first glance like a classic case of [SQL injection](https://en.wikipedia.org/wiki/SQL_injection) 😱, PostgresNIO's API ensures that this usage is safe. The first parameter of the `query(_:logger:)` method is not a plain `String`, but a `PostgresQuery`, which implements Swift's `ExpressibleByStringInterpolation` protocol. PostgresNIO uses the literal parts of the provided string as the SQL query and replaces each interpolated value with a parameter binding. Only values which implement the `PostgresEncodable` protocol may be interpolated in this way. As with `PostgresDecodable`, PostgresNIO provides default implementations for most common types. -print(data.date) // Date? -print(data.uuid) // UUID? +Some queries do not receive any rows from the server (most often `INSERT`, `UPDATE`, and `DELETE` queries with no `RETURNING` clause, not to mention most DDL queries). To support this, the `query(_:logger:)` method is marked `@discardableResult`, so that the compiler does not issue a warning if the return value is not used. -print(data.numeric) // PostgresNumeric? -``` +## Security + +Please see [SECURITY.md](https://github.com/vapor/.github/blob/main/SECURITY.md) for details on the security process. -`PostgresData` is also used for sending data _to_ the server via parameterized values. To create `PostgresData` from a Swift type, use the available intializer methods. +[EventLoopGroupConnectionPool]: https://github.com/vapor/async-kit/blob/main/Sources/AsyncKit/ConnectionPool/EventLoopGroupConnectionPool.swift +[AsyncKit]: https://github.com/vapor/async-kit/ +[SwiftNIO]: https://github.com/apple/swift-nio +[SwiftLog]: https://github.com/apple/swift-log From 57fda42d42176ee5309e7163f3a36dc20f626685 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 11 Mar 2022 15:32:10 +0100 Subject: [PATCH 090/292] Fix missing links (#255) --- README.md | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 85c763cd..6cc0d158 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ Features: -- A `PostgresConnection` which allows you to connect to, authorize with, query, and retrieve results from a PostgreSQL server +- A [`PostgresConnection`] which allows you to connect to, authorize with, query, and retrieve results from a PostgreSQL server - An async/await interface that supports backpressure - Automatic conversions between Swift primitive types and the Postgres wire format - Integrated with the Swift server ecosystem, including use of [SwiftLog]. @@ -95,7 +95,7 @@ let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) try eventLoopGroup.syncShutdown() ``` -A `Logger` is also required. +A [`Logger`] is also required. ```swift import Logging @@ -166,7 +166,7 @@ for try await (id, username, birthday) in rows.decode((Int, String, Date).self, } ``` -A type must implement the `PostgresDecodable` protocol in order to be decoded from a row. PostgresNIO provides default implementations for most of Swift's builtin types, as well as some types provided by Foundation: +A type must implement the [`PostgresDecodable`] protocol in order to be decoded from a row. PostgresNIO provides default implementations for most of Swift's builtin types, as well as some types provided by Foundation: - `Bool` - `Bytes`, `Data`, `ByteBuffer` @@ -191,15 +191,24 @@ try await connection.query(""" ) ``` -While this looks at first glance like a classic case of [SQL injection](https://en.wikipedia.org/wiki/SQL_injection) 😱, PostgresNIO's API ensures that this usage is safe. The first parameter of the `query(_:logger:)` method is not a plain `String`, but a `PostgresQuery`, which implements Swift's `ExpressibleByStringInterpolation` protocol. PostgresNIO uses the literal parts of the provided string as the SQL query and replaces each interpolated value with a parameter binding. Only values which implement the `PostgresEncodable` protocol may be interpolated in this way. As with `PostgresDecodable`, PostgresNIO provides default implementations for most common types. +While this looks at first glance like a classic case of [SQL injection](https://en.wikipedia.org/wiki/SQL_injection) 😱, PostgresNIO's API ensures that this usage is safe. The first parameter of the [`query(_:logger:)`] method is not a plain `String`, but a [`PostgresQuery`], which implements Swift's `ExpressibleByStringInterpolation` protocol. PostgresNIO uses the literal parts of the provided string as the SQL query and replaces each interpolated value with a parameter binding. Only values which implement the [`PostgresEncodable`] protocol may be interpolated in this way. As with [`PostgresDecodable`], PostgresNIO provides default implementations for most common types. -Some queries do not receive any rows from the server (most often `INSERT`, `UPDATE`, and `DELETE` queries with no `RETURNING` clause, not to mention most DDL queries). To support this, the `query(_:logger:)` method is marked `@discardableResult`, so that the compiler does not issue a warning if the return value is not used. +Some queries do not receive any rows from the server (most often `INSERT`, `UPDATE`, and `DELETE` queries with no `RETURNING` clause, not to mention most DDL queries). To support this, the [`query(_:logger:)`] method is marked `@discardableResult`, so that the compiler does not issue a warning if the return value is not used. ## Security Please see [SECURITY.md](https://github.com/vapor/.github/blob/main/SECURITY.md) for details on the security process. -[EventLoopGroupConnectionPool]: https://github.com/vapor/async-kit/blob/main/Sources/AsyncKit/ConnectionPool/EventLoopGroupConnectionPool.swift -[AsyncKit]: https://github.com/vapor/async-kit/ +[`PostgresConnection`]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/PostgresConnection/ +[`query(_:logger:)`]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/PostgresConnection/#postgresconnection.query(_:logger:file:line:) +[`PostgresQuery`]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/PostgresQuery/ +[`PostgresRow`]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/PostgresRow/ +[`PostgresRowSequence`]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/PostgresRowSequence/ +[`PostgresDecodable`]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/PostgresDecodable/ +[`PostgresEncodable`]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/PostgresEncodable/ + +[PostgresKit]: https://github.com/vapor/postgres-kit + [SwiftNIO]: https://github.com/apple/swift-nio [SwiftLog]: https://github.com/apple/swift-log +[`Logger`]: https://apple.github.io/swift-log/docs/current/Logging/Structs/Logger.html From b1a9a328438871f9224236caed7cc7dffc0d54b1 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 14 Mar 2022 13:36:39 +0100 Subject: [PATCH 091/292] Fix CI test link in README (#260) Co-authored-by: Gwynne Raskind --- .github/workflows/main-codecov.yml | 23 +++++++---------------- README.md | 5 ++--- 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/.github/workflows/main-codecov.yml b/.github/workflows/main-codecov.yml index 7a55c3ae..85a794f1 100644 --- a/.github/workflows/main-codecov.yml +++ b/.github/workflows/main-codecov.yml @@ -1,22 +1,13 @@ -name: main codecov +name: CI for main on: push: branches: - main jobs: update-main-codecov: - runs-on: ubuntu-latest - container: swift:5.5-focal - steps: - - name: Check out main - uses: actions/checkout@v2 - - name: Run unit tests with code coverage and Thread Sanitizer - run: swift test --enable-code-coverage --sanitize=thread --filter=^PostgresNIOTests - - name: Submit coverage report to Codecov.io - uses: vapor/swift-codecov-action@v0.1.1 - with: - cc_flags: 'unittests' - cc_env_vars: 'SWIFT_VERSION,SWIFT_PLATFORM,RUNNER_OS,RUNNER_ARCH' - cc_fail_ci_if_error: true - cc_verbose: true - cc_dry_run: false + uses: vapor/ci/.github/workflows/run-unit-tests.yml@reusable-workflows + with: + with_coverage: true + with_tsan: true + coverage_ignores: '/Tests/' + test_filter: '^PostgresNIOTests' diff --git a/README.md b/README.md index 6cc0d158..e558e046 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ MIT License
- Continuous Integration + Continuous Integration Swift 5.2 @@ -32,8 +32,7 @@ Features: - Designed to run efficiently on all supported platforms (tested extensively on Linux and Darwin systems) - Support for `Network.framework` when available (e.g. on Apple platforms) -PostgresNIO does not have a `ConnectionPool` as of today, but this is a feature high on our list. If -you need a `ConnectionPool` today, please have a look at Vapor's [PostgresKit]. +PostgresNIO does not provide a `ConnectionPool` as of today, but this is a [feature high on our list](https://github.com/vapor/postgres-nio/issues/256). If you need a `ConnectionPool` today, please have a look at Vapor's [PostgresKit]. ## API Docs From f11cc9bd30164079a58306e5085e1908b61e80db Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 16 Mar 2022 07:43:34 +0100 Subject: [PATCH 092/292] Allow unescaped SQL interpolation in PostgresQuery (#258) --- Sources/PostgresNIO/New/PostgresQuery.swift | 5 +++ .../New/PostgresQueryTests.swift | 44 ++++++++++++------- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 5bb33988..276d969f 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -67,6 +67,11 @@ extension PostgresQuery { try self.binds.append(value, context: context) self.sql.append(contentsOf: "$\(self.binds.count)") } + + @inlinable + public mutating func appendInterpolation(unescaped interpolated: String) { + self.sql.append(contentsOf: interpolated) + } } } diff --git a/Tests/PostgresNIOTests/New/PostgresQueryTests.swift b/Tests/PostgresNIOTests/New/PostgresQueryTests.swift index 43c39a3a..68fd8b9b 100644 --- a/Tests/PostgresNIOTests/New/PostgresQueryTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresQueryTests.swift @@ -3,16 +3,18 @@ import XCTest final class PostgresQueryTests: XCTestCase { - func testStringInterpolationWithOptional() throws { + func testStringInterpolationWithOptional() { let string = "Hello World" let null: UUID? = nil let uuid: UUID? = UUID() - let query: PostgresQuery = try """ + var query: PostgresQuery? + XCTAssertNoThrow(query = try """ INSERT INTO foo (id, title, something) SET (\(uuid), \(string), \(null)); """ + ) - XCTAssertEqual(query.sql, "INSERT INTO foo (id, title, something) SET ($1, $2, $3);") + XCTAssertEqual(query?.sql, "INSERT INTO foo (id, title, something) SET ($1, $2, $3);") var expected = ByteBuffer() expected.writeInteger(Int32(16)) @@ -27,10 +29,10 @@ final class PostgresQueryTests: XCTestCase { expected.writeString(string) expected.writeInteger(Int32(-1)) - XCTAssertEqual(query.binds.bytes, expected) + XCTAssertEqual(query?.binds.bytes, expected) } - func testStringInterpolationWithCustomJSONEncoder() throws { + func testStringInterpolationWithCustomJSONEncoder() { struct Foo: Codable, PostgresCodable { var helloWorld: String } @@ -38,11 +40,13 @@ final class PostgresQueryTests: XCTestCase { let jsonEncoder = JSONEncoder() jsonEncoder.keyEncodingStrategy = .convertToSnakeCase - let query: PostgresQuery = try """ + var query: PostgresQuery? + XCTAssertNoThrow(query = try """ INSERT INTO test (foo) SET (\(Foo(helloWorld: "bar"), context: .init(jsonEncoder: jsonEncoder))); """ + ) - XCTAssertEqual(query.sql, "INSERT INTO test (foo) SET ($1);") + XCTAssertEqual(query?.sql, "INSERT INTO test (foo) SET ($1);") let expectedJSON = #"{"hello_world":"bar"}"# @@ -51,17 +55,10 @@ final class PostgresQueryTests: XCTestCase { expected.writeInteger(UInt8(0x01)) expected.writeString(expectedJSON) - XCTAssertEqual(query.binds.bytes, expected) + XCTAssertEqual(query?.binds.bytes, expected) } - func testAllowUsersToGenerateLotsOfRows() throws { - struct Foo: Codable, PostgresCodable { - var helloWorld: String - } - - let jsonEncoder = JSONEncoder() - jsonEncoder.keyEncodingStrategy = .convertToSnakeCase - + func testAllowUsersToGenerateLotsOfRows() { let sql = "INSERT INTO test (id) SET (\((1...5).map({"$\($0)"}).joined(separator: ", ")));" var query = PostgresQuery(unsafeSQL: sql, binds: .init(capacity: 5)) @@ -79,4 +76,19 @@ final class PostgresQueryTests: XCTestCase { XCTAssertEqual(query.binds.bytes, expected) } + + func testUnescapedSQL() { + let tableName = UUID().uuidString.uppercased() + let value = 1 + + var query: PostgresQuery? + XCTAssertNoThrow(query = try "INSERT INTO \(unescaped: tableName) (id) SET (\(value));") + XCTAssertEqual(query?.sql, "INSERT INTO \(tableName) (id) SET ($1);") + + var expected = ByteBuffer() + expected.writeInteger(UInt32(8)) + expected.writeInteger(value) + + XCTAssertEqual(query?.binds.bytes, expected) + } } From 33fc9575579cd3a4cb32ec17edb404b505af1ebc Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 18 Mar 2022 10:27:24 +0100 Subject: [PATCH 093/292] Remove state machine log (#266) --- Sources/PostgresNIO/New/PostgresChannelHandler.swift | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 55d7aff1..33c0e3f1 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -13,11 +13,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { typealias OutboundOut = ByteBuffer private let logger: Logger - private var state: ConnectionStateMachine { - didSet { - self.logger.trace("Connection state changed", metadata: [.connectionState: "\(self.state)"]) - } - } + private var state: ConnectionStateMachine /// A `ChannelHandlerContext` to be used for non channel related events. (for example: More rows needed). /// From c72516db641521922d96dd72f0eccc3f2fb47468 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 18 Mar 2022 10:44:39 +0100 Subject: [PATCH 094/292] Use Int(exactly:) instead of restricting to 64 bit platforms (#267) --- .../PostgresNIO/Data/PostgresData+Int.swift | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/Sources/PostgresNIO/Data/PostgresData+Int.swift b/Sources/PostgresNIO/Data/PostgresData+Int.swift index ce77dd43..4729021f 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Int.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Int.swift @@ -1,7 +1,6 @@ extension PostgresData { public init(int value: Int) { - assert(Int.bitWidth == 64) - self.init(type: .int8, value: .init(integer: value)) + self.init(type: .int8, value: .init(integer: Int64(value))) } public init(uint8 value: UInt8) { @@ -32,25 +31,19 @@ extension PostgresData { guard value.readableBytes == 1 else { return nil } - return value.readInteger(as: UInt8.self) - .flatMap(Int.init) + return value.readInteger(as: UInt8.self).flatMap(Int.init) case .int2: assert(value.readableBytes == 2) - return value.readInteger(as: Int16.self) - .flatMap(Int.init) + return value.readInteger(as: Int16.self).flatMap(Int.init) case .int4, .regproc: assert(value.readableBytes == 4) - return value.readInteger(as: Int32.self) - .flatMap(Int.init) + return value.readInteger(as: Int32.self).flatMap(Int.init) case .oid: assert(value.readableBytes == 4) - assert(Int.bitWidth == 64) // or else overflow is possible - return value.readInteger(as: UInt32.self) - .flatMap(Int.init) + return value.readInteger(as: UInt32.self).flatMap { Int(exactly: $0) } case .int8: assert(value.readableBytes == 8) - assert(Int.bitWidth == 64) - return value.readInteger(as: Int.self) + return value.readInteger(as: Int64.self).flatMap { Int(exactly: $0) } default: return nil } From ab624e48f73f4c089ac12cdb6efa7d871f7f7824 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Fri, 18 Mar 2022 04:53:11 -0500 Subject: [PATCH 095/292] Update CI to 5.6 release and checkout@v3 (#269) --- .github/workflows/test.yml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 79021623..a9ff286a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,7 +10,7 @@ jobs: - swift:5.2 - swift:5.3 - swift:5.5 - - swiftlang/swift:nightly-5.6 + - swift:5.6 - swiftlang/swift:nightly-main swiftos: - focal @@ -20,11 +20,11 @@ jobs: LOG_LEVEL: debug steps: - name: Check out package - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Run unit tests with code coverage and Thread Sanitizer run: swift test --enable-test-discovery --filter=^PostgresNIOTests --sanitize=thread --enable-code-coverage - name: Submit coverage report to Codecov.io - uses: vapor/swift-codecov-action@v0.1.1 + uses: vapor/swift-codecov-action@v0.2 with: cc_flags: 'unittests' cc_env_vars: 'SWIFT_VERSION,SWIFT_PLATFORM,RUNNER_OS,RUNNER_ARCH' @@ -47,7 +47,7 @@ jobs: dbauth: md5 - dbimage: postgres:11 dbauth: trust - container: swift:5.5-focal + container: swift:5.6-focal runs-on: ubuntu-latest env: LOG_LEVEL: debug @@ -84,15 +84,15 @@ jobs: POSTGRES_INITDB_ARGS: --auth-host=${{ matrix.dbauth }} steps: - name: Check out package - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: { path: 'postgres-nio' } - name: Run integration tests run: swift test --package-path postgres-nio --filter=^IntegrationTests - name: Check out postgres-kit dependent - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: { repository: 'vapor/postgres-kit', path: 'postgres-kit' } - name: Check out fluent-postgres-driver dependent - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: { repository: 'vapor/fluent-postgres-driver', path: 'fluent-postgres-driver' } - name: Use local package in dependents run: | @@ -137,7 +137,7 @@ jobs: pg_ctl start --wait timeout-minutes: 2 - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Run all tests run: | swift test --enable-test-discovery -Xlinker -rpath \ From c1683ba3111caadf81515ffc0b620654883b410f Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 18 Mar 2022 11:09:17 +0100 Subject: [PATCH 096/292] Make forward progress when Query is cancelled (#261) Co-authored-by: Gwynne Raskind --- .../ConnectionStateMachine.swift | 12 ++- .../ExtendedQueryStateMachine.swift | 81 ++++++++++++++++-- .../RowStreamStateMachine.swift | 49 +++++++++++ Sources/PostgresNIO/New/PSQLError.swift | 7 +- .../New/PostgresChannelHandler.swift | 19 +++-- Sources/PostgresNIO/Postgres+PSQLCompat.swift | 2 + Tests/IntegrationTests/AsyncTests.swift | 19 +++++ .../ExtendedQueryStateMachineTests.swift | 85 +++++++++++++++++++ .../ConnectionAction+TestUtils.swift | 2 + 9 files changed, 262 insertions(+), 14 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index fa00328b..13de8281 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -842,7 +842,15 @@ struct ConnectionStateMachine { // MARK: Consumer mutating func cancelQueryStream() -> ConnectionAction { - preconditionFailure("Unimplemented") + guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { + preconditionFailure("Tried to cancel stream without active query") + } + + return self.avoidingStateMachineCoW { machine -> ConnectionAction in + let action = queryState.cancel() + machine.state = .extendedQuery(queryState, connectionContext) + return machine.modify(with: action) + } } mutating func requestQueryRows() -> ConnectionAction { @@ -1074,6 +1082,8 @@ extension ConnectionStateMachine { return true case .failedToAddSSLHandler: return true + case .queryCancelled: + return false case .server(let message): guard let sqlState = message.fields[.sqlState] else { // any error message that doesn't have a sql state field, is unexpected by default. diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 333742bb..fdde1aa8 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -2,7 +2,7 @@ import NIOCore struct ExtendedQueryStateMachine { - enum State { + private enum State { case initialized(ExtendedQueryContext) case parseDescribeBindExecuteSyncSent(ExtendedQueryContext) @@ -15,6 +15,8 @@ struct ExtendedQueryStateMachine { /// used after receiving a `bindComplete` message case bindCompleteReceived(ExtendedQueryContext) case streaming([RowDescription.Column], RowStreamStateMachine) + /// Indicates that the current query was cancelled and we want to drain rows from the connection ASAP + case drain([RowDescription.Column]) case commandComplete(commandTag: String) case error(PSQLError) @@ -41,9 +43,11 @@ struct ExtendedQueryStateMachine { case wait } - var state: State + private var state: State + private var isCancelled: Bool init(queryContext: ExtendedQueryContext) { + self.isCancelled = false self.state = .initialized(queryContext) } @@ -71,6 +75,44 @@ struct ExtendedQueryStateMachine { } } } + + mutating func cancel() -> Action { + switch self.state { + case .initialized: + preconditionFailure("Start must be called immediatly after the query was created") + + case .parseDescribeBindExecuteSyncSent(let queryContext), + .parseCompleteReceived(let queryContext), + .parameterDescriptionReceived(let queryContext), + .rowDescriptionReceived(let queryContext, _), + .noDataMessageReceived(let queryContext), + .bindCompleteReceived(let queryContext): + guard !self.isCancelled else { + return .wait + } + + self.isCancelled = true + return .failQuery(queryContext, with: .queryCancelled) + + case .streaming(let columns, var streamStateMachine): + precondition(!self.isCancelled) + self.isCancelled = true + self.state = .drain(columns) + switch streamStateMachine.fail() { + case .wait: + return .forwardStreamError(.queryCancelled, read: false) + case .read: + return .forwardStreamError(.queryCancelled, read: true) + } + + case .commandComplete, .error, .drain: + // the stream has already finished. + return .wait + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } mutating func parseCompletedReceived() -> Action { guard case .parseDescribeBindExecuteSyncSent(let queryContext) = self.state else { @@ -147,9 +189,11 @@ struct ExtendedQueryStateMachine { .parameterDescriptionReceived, .bindCompleteReceived, .streaming, + .drain, .commandComplete, .error: return self.setAndFireError(.unexpectedBackendMessage(.bindComplete)) + case .modifying: preconditionFailure("Invalid state") } @@ -169,6 +213,13 @@ struct ExtendedQueryStateMachine { state = .streaming(columns, demandStateMachine) return .wait } + + case .drain(let columns): + guard dataRow.columnCount == columns.count else { + return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) + } + // we ignore all rows and wait for readyForQuery + return .wait case .initialized, .parseDescribeBindExecuteSyncSent, @@ -198,6 +249,11 @@ struct ExtendedQueryStateMachine { state = .commandComplete(commandTag: commandTag) return .forwardStreamComplete(demandStateMachine.end(), commandTag: commandTag) } + + case .drain: + precondition(self.isCancelled) + self.state = .commandComplete(commandTag: commandTag) + return .wait case .initialized, .parseDescribeBindExecuteSyncSent, @@ -229,7 +285,7 @@ struct ExtendedQueryStateMachine { return self.setAndFireError(error) case .rowDescriptionReceived, .noDataMessageReceived: return self.setAndFireError(error) - case .streaming: + case .streaming, .drain: return self.setAndFireError(error) case .commandComplete: return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) @@ -269,6 +325,9 @@ struct ExtendedQueryStateMachine { } } + case .drain: + return .wait + case .initialized, .parseDescribeBindExecuteSyncSent, .parseCompleteReceived, @@ -291,6 +350,7 @@ struct ExtendedQueryStateMachine { switch self.state { case .initialized, .commandComplete, + .drain, .error, .parseDescribeBindExecuteSyncSent, .parseCompleteReceived, @@ -327,6 +387,7 @@ struct ExtendedQueryStateMachine { .bindCompleteReceived: return .read case .streaming(let columns, var demandStateMachine): + precondition(!self.isCancelled) return self.avoidingStateMachineCoW { state -> Action in let action = demandStateMachine.read() state = .streaming(columns, demandStateMachine) @@ -339,6 +400,7 @@ struct ExtendedQueryStateMachine { } case .initialized, .commandComplete, + .drain, .error: // we already have the complete stream received, now we are waiting for a // `readyForQuery` package. To receive this we need to read! @@ -361,11 +423,20 @@ struct ExtendedQueryStateMachine { .bindCompleteReceived(let context): self.state = .error(error) return .failQuery(context, with: error) - - case .streaming: + + case .drain: self.state = .error(error) return .forwardStreamError(error, read: false) + case .streaming(_, var streamStateMachine): + self.state = .error(error) + switch streamStateMachine.fail() { + case .wait: + return .forwardStreamError(error, read: false) + case .read: + return .forwardStreamError(error, read: true) + } + case .commandComplete, .error: preconditionFailure(""" This state must not be reached. If the query `.isComplete`, the diff --git a/Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift index 08953fb2..4bfd5e9b 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift @@ -23,6 +23,8 @@ struct RowStreamStateMachine { /// preserved for performance reasons. case waitingForDemand([DataRow]) + case failed + case modifying } @@ -63,6 +65,11 @@ struct RowStreamStateMachine { buffer.append(newRow) self.state = .waitingForReadOrDemand(buffer) + case .failed: + // Once the row stream state machine is marked as failed, no further events must be + // forwarded to it. + preconditionFailure("Invalid state: \(self.state)") + case .modifying: preconditionFailure("Invalid state: \(self.state)") } @@ -86,6 +93,11 @@ struct RowStreamStateMachine { .waitingForReadOrDemand: preconditionFailure("How can we receive a body part, after a channelReadComplete, but no read has been forwarded yet. Invalid state: \(self.state)") + case .failed: + // Once the row stream state machine is marked as failed, no further events must be + // forwarded to it. + preconditionFailure("Invalid state: \(self.state)") + case .modifying: preconditionFailure("Invalid state: \(self.state)") } @@ -111,6 +123,11 @@ struct RowStreamStateMachine { // the next `channelReadComplete` we will forward all buffered data return .wait + case .failed: + // Once the row stream state machine is marked as failed, no further events must be + // forwarded to it. + preconditionFailure("Invalid state: \(self.state)") + case .modifying: preconditionFailure("Invalid state: \(self.state)") } @@ -136,6 +153,11 @@ struct RowStreamStateMachine { // from the consumer return .wait + case .failed: + // Once the row stream state machine is marked as failed, no further events must be + // forwarded to it. + preconditionFailure("Invalid state: \(self.state)") + case .modifying: preconditionFailure("Invalid state: \(self.state)") } @@ -158,6 +180,33 @@ struct RowStreamStateMachine { // receive a call to `end()`, when we don't expect it here. return buffer + case .failed: + // Once the row stream state machine is marked as failed, no further events must be + // forwarded to it. + preconditionFailure("Invalid state: \(self.state)") + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + mutating func fail() -> Action { + switch self.state { + case .waitingForRows, + .waitingForReadOrDemand, + .waitingForRead: + self.state = .failed + return .wait + + case .waitingForDemand: + self.state = .failed + return .read + + case .failed: + // Once the row stream state machine is marked as failed, no further events must be + // forwarded to it. + preconditionFailure("Invalid state: \(self.state)") + case .modifying: preconditionFailure("Invalid state: \(self.state)") } diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index a993b538..cb09d12a 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -11,7 +11,8 @@ struct PSQLError: Error { case unsupportedAuthMechanism(PSQLAuthScheme) case authMechanismRequiresPassword case saslError(underlyingError: Error) - + + case queryCancelled case tooManyParameters case connectionQuiescing case connectionClosed @@ -58,6 +59,10 @@ struct PSQLError: Error { static func sasl(underlying: Error) -> PSQLError { Self.init(.saslError(underlyingError: underlying)) } + + static var queryCancelled: PSQLError { + Self.init(.queryCancelled) + } static var tooManyParameters: PSQLError { Self.init(.tooManyParameters) diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 33c0e3f1..348a9f21 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -18,7 +18,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { /// A `ChannelHandlerContext` to be used for non channel related events. (for example: More rows needed). /// /// The context is captured in `handlerAdded` and released` in `handlerRemoved` - private var handlerContext: ChannelHandlerContext! + private var handlerContext: ChannelHandlerContext? private var rowStream: PSQLRowStream? private var decoder: NIOSingleStepByteToMessageProcessor private var encoder: BufferedMessageEncoder! @@ -262,7 +262,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { case .forwardStreamComplete(let buffer, let commandTag): guard let rowStream = self.rowStream else { - preconditionFailure("Expected to have a row stream here.") + // if the stream was cancelled we don't have it here anymore. + return } self.rowStream = nil if buffer.count > 0 { @@ -499,18 +500,20 @@ final class PostgresChannelHandler: ChannelDuplexHandler { extension PostgresChannelHandler: PSQLRowsDataSource { func request(for stream: PSQLRowStream) { - guard self.rowStream === stream else { + guard self.rowStream === stream, let handlerContext = self.handlerContext else { return } let action = self.state.requestQueryRows() - self.run(action, with: self.handlerContext!) + self.run(action, with: handlerContext) } func cancel(for stream: PSQLRowStream) { - guard self.rowStream === stream else { + guard self.rowStream === stream, let handlerContext = self.handlerContext else { return } // we ignore this right now :) + let action = self.state.cancelQueryStream() + self.run(action, with: handlerContext) } } @@ -519,7 +522,8 @@ extension PostgresConnection.Configuration.Authentication { AuthContext( username: self.username, password: self.password, - database: self.database) + database: self.database + ) } } @@ -529,7 +533,8 @@ extension AuthContext { user: self.username, database: self.database, options: nil, - replication: .false) + replication: .false + ) } } diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index 8c7e7db1..233c925f 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -3,6 +3,8 @@ import NIOCore extension PSQLError { func toPostgresError() -> Error { switch self.base { + case .queryCancelled: + return self case .server(let errorMessage): var fields = [PostgresMessage.Error.Field: String]() fields.reserveCapacity(errorMessage.fields.count) diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index d28a9e62..cb6950d6 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -45,6 +45,25 @@ final class AsyncPostgresConnectionTests: XCTestCase { } } + func testSelect10times10kRows() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let start = 1 + let end = 10000 + + try await withTestConnection(on: eventLoop) { connection in + await withThrowingTaskGroup(of: Void.self) { taskGroup in + for _ in 0..<10 { + taskGroup.addTask { + try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest) + } + } + } + } + } + #if canImport(Network) func testSelect10kRowsNetworkFramework() async throws { let eventLoopGroup = NIOTSEventLoopGroup() diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index b5055929..bae4c986 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -96,4 +96,89 @@ class ExtendedQueryStateMachineTests: XCTestCase { .failQuery(queryContext, with: psqlError, cleanupContext: .init(action: .close, tasks: [], error: psqlError, closePromise: nil))) } + func testExtendedQueryIsCancelledImmediatly() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "SELECT version()" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + + // We need to ensure that even though the row description from the wire says that we + // will receive data in `.text` format, we will actually receive it in binary format, + // since we requested it in binary with our bind message. + let input: [RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) + ] + let expected: [RowDescription.Column] = input.map { + .init(name: $0.name, tableOID: $0.tableOID, columnAttributeNumber: $0.columnAttributeNumber, dataType: $0.dataType, + dataTypeSize: $0.dataTypeSize, dataTypeModifier: $0.dataTypeModifier, format: .binary) + } + + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(queryContext, columns: expected)) + XCTAssertEqual(state.cancelQueryStream(), .forwardStreamError(.queryCancelled, read: false, cleanupContext: nil)) + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test1")]), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test2")]), .wait) + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test3")]), .wait) + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test4")]), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.commandCompletedReceived("SELECT 2"), .wait) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testExtendedQueryIsCancelledWithReadPending() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "SELECT version()" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + + // We need to ensure that even though the row description from the wire says that we + // will receive data in `.text` format, we will actually receive it in binary format, + // since we requested it in binary with our bind message. + let input: [RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) + ] + let expected: [RowDescription.Column] = input.map { + .init(name: $0.name, tableOID: $0.tableOID, columnAttributeNumber: $0.columnAttributeNumber, dataType: $0.dataType, + dataTypeSize: $0.dataTypeSize, dataTypeModifier: $0.dataTypeModifier, format: .binary) + } + + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(queryContext, columns: expected)) + let row1: DataRow = [ByteBuffer(string: "test1")] + XCTAssertEqual(state.dataRowReceived(row1), .wait) + XCTAssertEqual(state.channelReadComplete(), .forwardRows([row1])) + XCTAssertEqual(state.readEventCaught(), .wait) + XCTAssertEqual(state.cancelQueryStream(), .forwardStreamError(.queryCancelled, read: true, cleanupContext: nil)) + + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test2")]), .wait) + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test3")]), .wait) + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test4")]), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.commandCompletedReceived("SELECT 4"), .wait) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } } diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift index 448183b5..fdc69b81 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -36,6 +36,8 @@ extension ConnectionStateMachine.ConnectionAction: Equatable { return lhsRows == rhsRows case (.forwardStreamComplete(let lhsBuffer, let lhsCommandTag), .forwardStreamComplete(let rhsBuffer, let rhsCommandTag)): return lhsBuffer == rhsBuffer && lhsCommandTag == rhsCommandTag + case (.forwardStreamError(let lhsError, let lhsRead, let lhsCleanupContext), .forwardStreamError(let rhsError , let rhsRead, let rhsCleanupContext)): + return lhsError == rhsError && lhsRead == rhsRead && lhsCleanupContext == rhsCleanupContext case (.sendParseDescribeSync(let lhsName, let lhsQuery), .sendParseDescribeSync(let rhsName, let rhsQuery)): return lhsName == rhsName && lhsQuery == rhsQuery case (.succeedPreparedStatementCreation(let lhsContext, let lhsRowDescription), .succeedPreparedStatementCreation(let rhsContext, let rhsRowDescription)): From e9e431cbb3da260ef39507d7c4f757e96f136820 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 18 Mar 2022 11:21:26 +0100 Subject: [PATCH 097/292] Add EventLoop API that uses PostgresQuery (#265) --- .../Connection/PostgresConnection.swift | 70 ++++++++- .../ConnectionStateMachine.swift | 2 + Sources/PostgresNIO/New/PSQLError.swift | 5 + Sources/PostgresNIO/Postgres+PSQLCompat.swift | 2 +- .../PSQLIntegrationTests.swift | 144 ++++++++---------- 5 files changed, 134 insertions(+), 89 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index ad3d14e7..08b5149e 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -294,7 +294,7 @@ public final class PostgresConnection { // MARK: Query - func query(_ query: PostgresQuery, logger: Logger) -> EventLoopFuture { + private func queryStream(_ query: PostgresQuery, logger: Logger) -> EventLoopFuture { var logger = logger logger[postgresMetadataKey: .connectionID] = "\(self.id)" guard query.binds.count <= Int(Int16.max) else { @@ -433,6 +433,8 @@ extension PostgresConnection { } } +// MARK: Async/Await Interface + #if swift(>=5.5) && canImport(_Concurrency) extension PostgresConnection { @@ -489,7 +491,8 @@ extension PostgresConnection { let context = ExtendedQueryContext( query: query, logger: logger, - promise: promise) + promise: promise + ) self.channel.write(PSQLTask.extendedQuery(context), promise: nil) @@ -498,7 +501,64 @@ extension PostgresConnection { } #endif -// MARK: PostgresDatabase +// MARK: EventLoopFuture interface + +extension PostgresConnection { + + /// Run a query on the Postgres server the connection is connected to and collect all rows. + /// + /// - Parameters: + /// - query: The ``PostgresQuery`` to run + /// - logger: The `Logger` to log into for the query + /// - file: The file, the query was started in. Used for better error reporting. + /// - line: The line, the query was started in. Used for better error reporting. + /// - Returns: An EventLoopFuture, that allows access to the future ``PostgresQueryResult``. + public func query( + _ query: PostgresQuery, + logger: Logger, + file: String = #file, + line: Int = #line + ) -> EventLoopFuture { + self.queryStream(query, logger: logger).flatMap { rowStream in + rowStream.all().flatMapThrowing { rows -> PostgresQueryResult in + guard let metadata = PostgresQueryMetadata(string: rowStream.commandTag) else { + throw PSQLError.invalidCommandTag(rowStream.commandTag) + } + return PostgresQueryResult(metadata: metadata, rows: rows) + } + } + } + + /// Run a query on the Postgres server the connection is connected to and iterate the rows in a callback. + /// + /// - Note: This API does not support back-pressure. If you need back-pressure please use the query + /// API, that supports structured concurrency. + /// - Parameters: + /// - query: The ``PostgresQuery`` to run + /// - logger: The `Logger` to log into for the query + /// - file: The file, the query was started in. Used for better error reporting. + /// - line: The line, the query was started in. Used for better error reporting. + /// - onRow: A closure that is invoked for every row. + /// - Returns: An EventLoopFuture, that allows access to the future ``PostgresQueryMetadata``. + public func query( + _ query: PostgresQuery, + logger: Logger, + file: String = #file, + line: Int = #line, + _ onRow: @escaping (PostgresRow) throws -> () + ) -> EventLoopFuture { + self.queryStream(query, logger: logger).flatMap { rowStream in + rowStream.onRow(onRow).flatMapThrowing { () -> PostgresQueryMetadata in + guard let metadata = PostgresQueryMetadata(string: rowStream.commandTag) else { + throw PSQLError.invalidCommandTag(rowStream.commandTag) + } + return metadata + } + } + } +} + +// MARK: PostgresDatabase conformance extension PostgresConnection: PostgresDatabase { public func send( @@ -513,14 +573,14 @@ extension PostgresConnection: PostgresDatabase { switch command { case .query(let query, let onMetadata, let onRow): - resultFuture = self.query(query, logger: logger).flatMap { stream in + resultFuture = self.queryStream(query, logger: logger).flatMap { stream in return stream.onRow(onRow).map { _ in onMetadata(PostgresQueryMetadata(string: stream.commandTag)!) } } case .queryAll(let query, let onResult): - resultFuture = self.query(query, logger: logger).flatMap { rows in + resultFuture = self.queryStream(query, logger: logger).flatMap { rows in return rows.all().map { allrows in onResult(.init(metadata: PostgresQueryMetadata(string: rows.commandTag)!, rows: allrows)) } diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 13de8281..31a9ba1d 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -1108,6 +1108,8 @@ extension ConnectionStateMachine { return true case .tooManyParameters: return true + case .invalidCommandTag: + return true case .connectionQuiescing: preconditionFailure("Pure client error, that is thrown directly in PostgresConnection") case .connectionClosed: diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index cb09d12a..fd402618 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -11,6 +11,7 @@ struct PSQLError: Error { case unsupportedAuthMechanism(PSQLAuthScheme) case authMechanismRequiresPassword case saslError(underlyingError: Error) + case invalidCommandTag(String) case queryCancelled case tooManyParameters @@ -60,6 +61,10 @@ struct PSQLError: Error { Self.init(.saslError(underlyingError: underlying)) } + static func invalidCommandTag(_ value: String) -> PSQLError { + Self.init(.invalidCommandTag(value)) + } + static var queryCancelled: PSQLError { Self.init(.queryCancelled) } diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index 233c925f..674b4273 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -26,7 +26,7 @@ extension PSQLError { return PostgresError.protocol("Unable to authenticate without password") case .saslError(underlyingError: let underlying): return underlying - case .tooManyParameters: + case .tooManyParameters, .invalidCommandTag: return self case .connectionQuiescing: return PostgresError.connectionClosed diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 2e4de247..38443c5f 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -59,10 +59,9 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - var stream: PSQLRowStream? - XCTAssertNoThrow(stream = try conn?.query("SELECT version()", logger: .psqlTest).wait()) - var rows: [PostgresRow]? - XCTAssertNoThrow(rows = try XCTUnwrap(stream).all().wait()) + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query("SELECT version()", logger: .psqlTest).wait()) + let rows = result?.rows var version: String? XCTAssertNoThrow(version = try rows?.first?.decode(String.self, context: .default)) XCTAssertEqual(version?.contains("PostgreSQL"), true) @@ -77,12 +76,9 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - var stream: PSQLRowStream? - XCTAssertNoThrow(stream = try conn?.query("SELECT generate_series(1, 10000);", logger: .psqlTest).wait()) - + var metadata: PostgresQueryMetadata? var received: Int64 = 0 - - XCTAssertNoThrow(try stream?.onRow { row in + XCTAssertNoThrow(metadata = try conn?.query("SELECT generate_series(1, 10000);", logger: .psqlTest) { row in func workaround() { var number: Int64? XCTAssertNoThrow(number = try row.decode(Int64.self, context: .default)) @@ -94,6 +90,8 @@ final class IntegrationTests: XCTestCase { }.wait()) XCTAssertEqual(received, 10000) + XCTAssertEqual(metadata?.command, "SELECT") + XCTAssertEqual(metadata?.rows, 10000) } func test1kRoundTrips() { @@ -106,12 +104,10 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try conn?.close().wait()) } for _ in 0..<1_000 { - var stream: PSQLRowStream? - XCTAssertNoThrow(stream = try conn?.query("SELECT version()", logger: .psqlTest).wait()) - var rows: [PostgresRow]? - XCTAssertNoThrow(rows = try XCTUnwrap(stream).all().wait()) + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query("SELECT version()", logger: .psqlTest).wait()) var version: String? - XCTAssertNoThrow(version = try rows?.first?.decode(String.self, context: .default)) + XCTAssertNoThrow(version = try result?.rows.first?.decode(String.self, context: .default)) XCTAssertEqual(version?.contains("PostgreSQL"), true) } } @@ -125,12 +121,10 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - var stream: PSQLRowStream? - XCTAssertNoThrow(stream = try conn?.query("SELECT \("hello")::TEXT as foo", logger: .psqlTest).wait()) - var rows: [PostgresRow]? - XCTAssertNoThrow(rows = try XCTUnwrap(stream).all().wait()) + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query("SELECT \("hello")::TEXT as foo", logger: .psqlTest).wait()) var foo: String? - XCTAssertNoThrow(foo = try rows?.first?.decode(String.self, context: .default)) + XCTAssertNoThrow(foo = try result?.rows.first?.decode(String.self, context: .default)) XCTAssertEqual(foo, "hello") } @@ -143,8 +137,8 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - var stream: PSQLRowStream? - XCTAssertNoThrow(stream = try conn?.query(""" + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query(""" SELECT 1::SMALLINT as smallint, -32767::SMALLINT as smallint_min, @@ -157,10 +151,8 @@ final class IntegrationTests: XCTestCase { 9223372036854775807::BIGINT as bigint_max """, logger: .psqlTest).wait()) - var rows: [PostgresRow]? - XCTAssertNoThrow(rows = try stream?.all().wait()) - XCTAssertEqual(rows?.count, 1) - let row = rows?.first + XCTAssertEqual(result?.rows.count, 1) + let row = result?.rows.first var cells: (Int16, Int16, Int16, Int32, Int32, Int32, Int64, Int64, Int64)? XCTAssertNoThrow(cells = try row?.decode((Int16, Int16, Int16, Int32, Int32, Int32, Int64, Int64, Int64).self, context: .default)) @@ -185,13 +177,11 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - var stream: PSQLRowStream? + var result: PostgresQueryResult? let array: [Int64] = [1, 2, 3] - XCTAssertNoThrow(stream = try conn?.query("SELECT \(array)::int8[] as array", logger: .psqlTest).wait()) - var rows: [PostgresRow]? - XCTAssertNoThrow(rows = try stream?.all().wait()) - XCTAssertEqual(rows?.count, 1) - XCTAssertEqual(try rows?.first?.decode([Int64].self, context: .default), array) + XCTAssertNoThrow(result = try conn?.query("SELECT \(array)::int8[] as array", logger: .psqlTest).wait()) + XCTAssertEqual(result?.rows.count, 1) + XCTAssertEqual(try result?.rows.first?.decode([Int64].self, context: .default), array) } func testDecodeEmptyIntegerArray() { @@ -203,13 +193,11 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - var stream: PSQLRowStream? - XCTAssertNoThrow(stream = try conn?.query("SELECT '{}'::int[] as array", logger: .psqlTest).wait()) + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query("SELECT '{}'::int[] as array", logger: .psqlTest).wait()) - var rows: [PostgresRow]? - XCTAssertNoThrow(rows = try stream?.all().wait()) - XCTAssertEqual(rows?.count, 1) - XCTAssertEqual(try rows?.first?.decode([Int64].self, context: .default), []) + XCTAssertEqual(result?.rows.count, 1) + XCTAssertEqual(try result?.rows.first?.decode([Int64].self, context: .default), []) } func testDoubleArraySerialization() { @@ -221,13 +209,11 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - var stream: PSQLRowStream? + var result: PostgresQueryResult? let doubles: [Double] = [3.14, 42] - XCTAssertNoThrow(stream = try conn?.query("SELECT \(doubles)::double precision[] as doubles", logger: .psqlTest).wait()) - var rows: [PostgresRow]? - XCTAssertNoThrow(rows = try stream?.all().wait()) - XCTAssertEqual(rows?.count, 1) - XCTAssertEqual(try rows?.first?.decode([Double].self, context: .default), doubles) + XCTAssertNoThrow(result = try conn?.query("SELECT \(doubles)::double precision[] as doubles", logger: .psqlTest).wait()) + XCTAssertEqual(result?.rows.count, 1) + XCTAssertEqual(try result?.rows.first?.decode([Double].self, context: .default), doubles) } func testDecodeDates() { @@ -239,20 +225,18 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - var stream: PSQLRowStream? - XCTAssertNoThrow(stream = try conn?.query(""" + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query(""" SELECT '2016-01-18 01:02:03 +0042'::DATE as date, '2016-01-18 01:02:03 +0042'::TIMESTAMP as timestamp, '2016-01-18 01:02:03 +0042'::TIMESTAMPTZ as timestamptz """, logger: .psqlTest).wait()) - var rows: [PostgresRow]? - XCTAssertNoThrow(rows = try stream?.all().wait()) - XCTAssertEqual(rows?.count, 1) + XCTAssertEqual(result?.rows.count, 1) var cells: (Date, Date, Date)? - XCTAssertNoThrow(cells = try rows?.first?.decode((Date, Date, Date).self, context: .default)) + XCTAssertNoThrow(cells = try result?.rows.first?.decode((Date, Date, Date).self, context: .default)) XCTAssertEqual(cells?.0.description, "2016-01-18 00:00:00 +0000") XCTAssertEqual(cells?.1.description, "2016-01-18 01:02:03 +0000") @@ -268,24 +252,22 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow(try conn?.close().wait()) } - var stream: PSQLRowStream? - XCTAssertNoThrow(stream = try conn?.query(""" + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query(""" SELECT \(Decimal(string: "123456.789123")!)::numeric as numeric, \(Decimal(string: "-123456.789123")!)::numeric as numeric_negative """, logger: .psqlTest).wait()) - var rows: [PostgresRow]? - XCTAssertNoThrow(rows = try stream?.all().wait()) - XCTAssertEqual(rows?.count, 1) + XCTAssertEqual(result?.rows.count, 1) var cells: (Decimal, Decimal)? - XCTAssertNoThrow(cells = try rows?.first?.decode((Decimal, Decimal).self, context: .default)) + XCTAssertNoThrow(cells = try result?.rows.first?.decode((Decimal, Decimal).self, context: .default)) XCTAssertEqual(cells?.0, Decimal(string: "123456.789123")) XCTAssertEqual(cells?.1, Decimal(string: "-123456.789123")) } - func testDecodeUUID() { + func testRoundTripUUID() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() @@ -296,15 +278,15 @@ final class IntegrationTests: XCTestCase { let uuidString = "2c68f645-9ca6-468b-b193-ee97f241c2f8" - var stream: PSQLRowStream? - XCTAssertNoThrow(stream = try conn?.query(""" + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query(""" SELECT \(uuidString)::UUID as uuid - """, logger: .psqlTest).wait()) + """, + logger: .psqlTest + ).wait()) - var rows: [PostgresRow]? - XCTAssertNoThrow(rows = try stream?.all().wait()) - XCTAssertEqual(rows?.count, 1) - XCTAssertEqual(try rows?.first?.decode(UUID.self, context: .default), UUID(uuidString: uuidString)) + XCTAssertEqual(result?.rows.count, 1) + XCTAssertEqual(try result?.rows.first?.decode(UUID.self, context: .default), UUID(uuidString: uuidString)) } func testRoundTripJSONB() { @@ -322,33 +304,29 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try conn?.close().wait()) } do { - var stream: PSQLRowStream? - XCTAssertNoThrow(stream = try conn?.query(""" + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query(""" select \(Object(foo: 1, bar: 2))::jsonb as jsonb """, logger: .psqlTest).wait()) - - var rows: [PostgresRow]? - XCTAssertNoThrow(rows = try stream?.all().wait()) - XCTAssertEqual(rows?.count, 1) - var result: Object? - XCTAssertNoThrow(result = try rows?.first?.decode(Object.self, context: .default)) - XCTAssertEqual(result?.foo, 1) - XCTAssertEqual(result?.bar, 2) + + XCTAssertEqual(result?.rows.count, 1) + var obj: Object? + XCTAssertNoThrow(obj = try result?.rows.first?.decode(Object.self, context: .default)) + XCTAssertEqual(obj?.foo, 1) + XCTAssertEqual(obj?.bar, 2) } do { - var stream: PSQLRowStream? - XCTAssertNoThrow(stream = try conn?.query(""" + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query(""" select \(Object(foo: 1, bar: 2))::json as json """, logger: .psqlTest).wait()) - - var rows: [PostgresRow]? - XCTAssertNoThrow(rows = try stream?.all().wait()) - XCTAssertEqual(rows?.count, 1) - var result: Object? - XCTAssertNoThrow(result = try rows?.first?.decode(Object.self, context: .default)) - XCTAssertEqual(result?.foo, 1) - XCTAssertEqual(result?.bar, 2) + + XCTAssertEqual(result?.rows.count, 1) + var obj: Object? + XCTAssertNoThrow(obj = try result?.rows.first?.decode(Object.self, context: .default)) + XCTAssertEqual(obj?.foo, 1) + XCTAssertEqual(obj?.bar, 2) } } } From def4fe8c8b4f58ce394771f27910dca107e9d8c3 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 18 Mar 2022 13:57:37 +0100 Subject: [PATCH 098/292] Add `hasColumn` to `PostgresRandomAccessRow` (#270) --- Sources/PostgresNIO/Data/PostgresRow.swift | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/Sources/PostgresNIO/Data/PostgresRow.swift b/Sources/PostgresNIO/Data/PostgresRow.swift index 3fda262a..f7cfd238 100644 --- a/Sources/PostgresNIO/Data/PostgresRow.swift +++ b/Sources/PostgresNIO/Data/PostgresRow.swift @@ -178,6 +178,13 @@ extension PostgresRandomAccessRow: RandomAccessCollection { } return self[index] } + + /// Checks if the row contains a cell for the given column name. + /// - Parameter column: The column name to check against + /// - Returns: `true` if the row contains this column, `false` if it does not. + public func contains(_ column: String) -> Bool { + self.lookupTable[column] != nil + } } extension PostgresRandomAccessRow { @@ -286,8 +293,8 @@ extension PostgresRow { @available(*, deprecated, message: """ This call is O(n) where n is the number of cells in the row. For random access to cells - in a row create a PostgresRandomAccessCollection from the row first and use its subscript - methods. + in a row create a PostgresRandomAccessRow from the row first and use its subscript + methods. (see `makeRandomAccess()`) """) public func column(_ column: String) -> PostgresData? { guard let index = self.lookupTable[column] else { From 3c0efd755c689f7f6370a63206524487b14b4eda Mon Sep 17 00:00:00 2001 From: Nick Otto Date: Tue, 26 Apr 2022 02:41:46 -0400 Subject: [PATCH 099/292] Make `PostgresRowSequence.collect` public (#281) Fixes #279 --- Sources/PostgresNIO/New/PostgresRowSequence.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/PostgresNIO/New/PostgresRowSequence.swift b/Sources/PostgresNIO/New/PostgresRowSequence.swift index 8159e679..2298c541 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence.swift @@ -531,7 +531,7 @@ extension AsyncStreamConsumer { } extension PostgresRowSequence { - func collect() async throws -> [PostgresRow] { + public func collect() async throws -> [PostgresRow] { var result = [PostgresRow]() for try await row in self { result.append(row) From 40b9f9938b03b347d3220c83760de9d1340fa314 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Tue, 26 Apr 2022 12:33:49 -0500 Subject: [PATCH 100/292] Switch this repo to doing its CI independently, plus some updates (#284) * Switch this repo to doing its CI independently, plus some updates * Workaround for git safe.directory issue (see checkout/actions#766) --- .github/workflows/main-codecov.yml | 13 ---------- .github/workflows/test.yml | 40 ++++++++++++++++++++++++------ 2 files changed, 32 insertions(+), 21 deletions(-) delete mode 100644 .github/workflows/main-codecov.yml diff --git a/.github/workflows/main-codecov.yml b/.github/workflows/main-codecov.yml deleted file mode 100644 index 85a794f1..00000000 --- a/.github/workflows/main-codecov.yml +++ /dev/null @@ -1,13 +0,0 @@ -name: CI for main -on: - push: - branches: - - main -jobs: - update-main-codecov: - uses: vapor/ci/.github/workflows/run-unit-tests.yml@reusable-workflows - with: - with_coverage: true - with_tsan: true - coverage_ignores: '/Tests/' - test_filter: '^PostgresNIOTests' diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a9ff286a..10190870 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,20 +1,25 @@ -name: test -on: [ 'pull_request' ] +name: CI +on: + push: + branches: + - "main" + pull_request: + branches: + - "*" jobs: linux-unit: strategy: fail-fast: false matrix: - swiftver: - - swift:5.2 - - swift:5.3 + swift: + - swift:5.4 - swift:5.5 - swift:5.6 - swiftlang/swift:nightly-main - swiftos: + os: - focal - container: ${{ format('{0}-{1}', matrix.swiftver, matrix.swiftos) }} + container: ${{ format('{0}-{1}', matrix.swift, matrix.os) }} runs-on: ubuntu-latest env: LOG_LEVEL: debug @@ -33,6 +38,7 @@ jobs: cc_dry_run: false linux-integration-and-dependencies: + if: github.event_name == 'pull_request' strategy: fail-fast: false matrix: @@ -104,6 +110,7 @@ jobs: run: swift test --package-path fluent-postgres-driver macos-all: + if: github.event_name == 'pull_request' strategy: fail-fast: false matrix: @@ -116,7 +123,7 @@ jobs: xcode: - latest-stable #- latest - runs-on: macos-11 + runs-on: macos-12 env: LOG_LEVEL: debug POSTGRES_HOSTNAME: 127.0.0.1 @@ -142,3 +149,20 @@ jobs: run: | swift test --enable-test-discovery -Xlinker -rpath \ -Xlinker $(xcode-select -p)/Toolchains/XcodeDefault.xctoolchain/usr/lib/swift-5.5/macosx + + api-breakage: + if: github.event_name == 'pull_request' + runs-on: ubuntu-latest + container: + image: swift:5.6-focal + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + fetch-depth: 0 + # https://github.com/actions/checkout/issues/766 + - name: Mark the workspace as safe + run: git config --global --add safe.directory ${GITHUB_WORKSPACE} + - name: API breaking changes + run: | + swift package diagnose-api-breaking-changes origin/main From 0036a89525b7528e8fe35d6b41b414ee81fc0802 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 26 Apr 2022 22:37:09 +0200 Subject: [PATCH 101/292] Drop Swift 5.2 and 5.3 support (#287) SwiftNIO has dropped support for Swift 5.2 and 5.3. https://github.com/apple/swift-nio/pull/2080 We should do the same. --- Package.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Package.swift b/Package.swift index 2dacd63f..44d4edef 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.2 +// swift-tools-version:5.4 import PackageDescription let package = Package( From efd11c5fa6bc33c76a24379df8409a1d0e39fe2d Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 26 Apr 2022 22:52:16 +0200 Subject: [PATCH 102/292] Remove unused scripts (#288) --- scripts/check_no_api_breakages.sh | 122 ------------------------------ scripts/run_no_api_breakages.sh | 8 -- 2 files changed, 130 deletions(-) delete mode 100755 scripts/check_no_api_breakages.sh delete mode 100755 scripts/run_no_api_breakages.sh diff --git a/scripts/check_no_api_breakages.sh b/scripts/check_no_api_breakages.sh deleted file mode 100755 index 73c3fb46..00000000 --- a/scripts/check_no_api_breakages.sh +++ /dev/null @@ -1,122 +0,0 @@ -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the SwiftNIO open source project -## -## Copyright (c) 2017-2020 Apple Inc. and the SwiftNIO project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of SwiftNIO project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -set -eu - -# repodir -function all_modules() { - local repodir="$1" - ( - set -eu - cd "$repodir" - swift package dump-package | jq '.products | - map(select(.type | has("library") )) | - map(.name) | .[]' | tr -d '"' - ) -} - -# repodir tag output -function build_and_do() { - local repodir=$1 - local tag=$2 - local output=$3 - - ( - cd "$repodir" - git checkout -q "$tag" - swift build --enable-test-discovery - while read -r module; do - swift api-digester -sdk "$sdk" -dump-sdk -module "$module" \ - -o "$output/$module.json" -I "$repodir/.build/debug" - done < <(all_modules "$repodir") - ) -} - -function usage() { - echo >&2 "Usage: $0 REPO-GITHUB-URL NEW-VERSION OLD-VERSIONS..." - echo >&2 - echo >&2 "This script requires a Swift 5.1+ toolchain." - echo >&2 - echo >&2 "Examples:" - echo >&2 - echo >&2 "Check between main and tag 2.1.1 of swift-nio:" - echo >&2 " $0 https://github.com/apple/swift-nio main 2.1.1" - echo >&2 - echo >&2 "Check between HEAD and commit 64cf63d7 using the provided toolchain:" - echo >&2 " xcrun --toolchain org.swift.5120190702a $0 ../some-local-repo HEAD 64cf63d7" -} - -if [[ $# -lt 3 ]]; then - usage - exit 1 -fi - -sdk=/ -if [[ "$(uname -s)" == Darwin ]]; then - sdk=$(xcrun --show-sdk-path) -fi - -hash jq 2> /dev/null || { echo >&2 "ERROR: jq must be installed"; exit 1; } -tmpdir=$(mktemp -d /tmp/.check-api_XXXXXX) -repo_url=$1 -new_tag=$2 -shift 2 - -repodir="$tmpdir/repo" -git clone "$repo_url" "$repodir" -git -C "$repodir" fetch -q origin '+refs/pull/*:refs/remotes/origin/pr/*' -errors=0 - -for old_tag in "$@"; do - mkdir "$tmpdir/api-old" - mkdir "$tmpdir/api-new" - - echo "Checking public API breakages from $old_tag to $new_tag" - - build_and_do "$repodir" "$new_tag" "$tmpdir/api-new/" - build_and_do "$repodir" "$old_tag" "$tmpdir/api-old/" - - for f in "$tmpdir/api-new"/*; do - f=$(basename "$f") - report="$tmpdir/$f.report" - if [[ ! -f "$tmpdir/api-old/$f" ]]; then - echo "NOTICE: NEW MODULE $f" - continue - fi - - echo -n "Checking $f... " - swift api-digester -sdk "$sdk" -diagnose-sdk \ - --input-paths "$tmpdir/api-old/$f" -input-paths "$tmpdir/api-new/$f" 2>&1 \ - > "$report" 2>&1 - - if ! shasum "$report" | grep -q cefc4ee5bb7bcdb7cb5a7747efa178dab3c794d5; then - echo ERROR - echo >&2 "==============================" - echo >&2 "ERROR: public API change in $f" - echo >&2 "==============================" - cat >&2 "$report" - errors=$(( errors + 1 )) - else - echo OK - fi - done - rm -rf "$tmpdir/api-new" "$tmpdir/api-old" -done - -if [[ "$errors" == 0 ]]; then - echo "OK, all seems good" -fi -echo done -exit "$errors" diff --git a/scripts/run_no_api_breakages.sh b/scripts/run_no_api_breakages.sh deleted file mode 100755 index 89bcba82..00000000 --- a/scripts/run_no_api_breakages.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -set -eu - -apt-get update -apt-get install -y jq - -./scripts/check_no_api_breakages.sh $1 $2 $3 From a89a2755301530afe394de87e6aa4444a0fa718a Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 27 Apr 2022 15:04:03 +0200 Subject: [PATCH 103/292] Update README to reflect latest changes (#289) --- README.md | 34 +++++++++++++++------------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index e558e046..dc7e14ed 100644 --- a/README.md +++ b/README.md @@ -1,23 +1,11 @@ PostgresNIO - - SSWG Incubating - - - Documentation - - - Team Chat - - - MIT License - - - Continuous Integration - - - Swift 5.2 - +[![SSWG Incubating Badge](https://img.shields.io/badge/sswg-incubating-green.svg)][SSWG Incubation] +[![Documentation](http://img.shields.io/badge/read_the-docs-2196f3.svg)][Documentation] +[![Team Chat](https://img.shields.io/discord/431917998102675485.svg)][Team Chat] +[![MIT License](http://img.shields.io/badge/license-MIT-brightgreen.svg)][MIT License] +[![Continuous Integration](https://github.com/vapor/postgres-nio/actions/workflows/test.yml/badge.svg)][Continuous Integration] +[![Swift 5.4](http://img.shields.io/badge/swift-5.4-brightgreen.svg)][Swift 5.4]

@@ -196,7 +184,15 @@ Some queries do not receive any rows from the server (most often `INSERT`, `UPDA ## Security -Please see [SECURITY.md](https://github.com/vapor/.github/blob/main/SECURITY.md) for details on the security process. +Please see [SECURITY.md] for details on the security process. + +[SSWG Incubation]: https://github.com/swift-server/sswg/blob/main/process/incubation.md#graduated-level +[Documentation]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/ +[Team Chat]: https://discord.gg/vapor +[MIT License]: LICENSE +[Continuous Integration]: https://github.com/vapor/postgres-nio/actions +[Swift 5.4]: https://swift.org +[Security.md]: https://github.com/vapor/.github/blob/main/SECURITY.md [`PostgresConnection`]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/PostgresConnection/ [`query(_:logger:)`]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/PostgresConnection/#postgresconnection.query(_:logger:file:line:) From a7a160bb861ca88b0c8f1cc11e15c4892ab897bf Mon Sep 17 00:00:00 2001 From: Nick Otto Date: Wed, 27 Apr 2022 18:10:40 -0400 Subject: [PATCH 104/292] Rename `PostgresCastingError` to `PostgresDecodingError` and make public (#286) Fixes #278. Co-authored-by: Fabian Fett --- Sources/PostgresNIO/Data/PostgresRow.swift | 6 +- .../New/Data/Array+PostgresCodable.swift | 10 +-- .../New/Data/Bool+PostgresCodable.swift | 12 +-- .../New/Data/Date+PostgresCodable.swift | 12 +-- .../New/Data/Decimal+PostgresCodable.swift | 8 +- .../New/Data/Float+PostgresCodable.swift | 20 ++--- .../New/Data/Int+PostgresCodable.swift | 42 ++++----- .../New/Data/JSON+PostgresCodable.swift | 6 +- .../RawRepresentable+PostgresCodable.swift | 4 +- .../New/Data/String+PostgresCodable.swift | 6 +- .../New/Data/UUID+PostgresCodable.swift | 14 +-- Sources/PostgresNIO/New/PSQLError.swift | 87 +++++++++---------- Sources/PostgresNIO/New/PostgresCell.swift | 4 +- Sources/PostgresNIO/New/PostgresCodable.swift | 8 +- .../New/PostgresRow-multi-decode.swift | 60 ++++++------- .../New/Data/Array+PSQLCodableTests.swift | 54 ++++++------ .../New/Data/Bool+PSQLCodableTests.swift | 32 +++---- .../New/Data/Date+PSQLCodableTests.swift | 34 ++++---- .../New/Data/Decimal+PSQLCodableTests.swift | 12 +-- .../New/Data/Float+PSQLCodableTests.swift | 40 ++++----- .../New/Data/JSON+PSQLCodableTests.swift | 28 +++--- .../RawRepresentable+PSQLCodableTests.swift | 16 ++-- .../New/Data/String+PSQLCodableTests.swift | 29 +++---- .../New/Data/UUID+PSQLCodableTests.swift | 46 +++++----- .../New/PostgresCellTests.swift | 2 +- .../New/PostgresCodableTests.swift | 10 +-- .../New/PostgresErrorTests.swift | 14 +-- dev/generate-postgresrow-multi-decode.sh | 4 +- 28 files changed, 306 insertions(+), 314 deletions(-) diff --git a/Sources/PostgresNIO/Data/PostgresRow.swift b/Sources/PostgresNIO/Data/PostgresRow.swift index f7cfd238..028fe656 100644 --- a/Sources/PostgresNIO/Data/PostgresRow.swift +++ b/Sources/PostgresNIO/Data/PostgresRow.swift @@ -250,8 +250,8 @@ extension PostgresRandomAccessRow { var cellSlice = self.cells[index] do { return try T._decodeRaw(from: &cellSlice, type: column.dataType, format: column.format, context: context) - } catch let code as PostgresCastingError.Code { - throw PostgresCastingError( + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( code: code, columnName: self.columns[index].name, columnIndex: index, @@ -330,5 +330,3 @@ extension PostgresRow: Sendable {} extension PostgresRandomAccessRow: Sendable {} #endif - - diff --git a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift index dd4e5620..c3bf3eb4 100644 --- a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift @@ -134,13 +134,13 @@ extension Array: PostgresDecodable where Element: PostgresArrayDecodable, Elemen ) throws { guard case .binary = format else { // currently we only support decoding arrays in binary format. - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } guard let (isNotEmpty, b, element) = buffer.readMultipleIntegers(endianness: .big, as: (Int32, Int32, UInt32).self), 0 <= isNotEmpty, isNotEmpty <= 1, b == 0 else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } let elementType = PostgresDataType(element) @@ -154,7 +154,7 @@ extension Array: PostgresDecodable where Element: PostgresArrayDecodable, Elemen expectedArrayCount > 0, dimensions == 1 else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } var result = Array() @@ -162,11 +162,11 @@ extension Array: PostgresDecodable where Element: PostgresArrayDecodable, Elemen for _ in 0 ..< expectedArrayCount { guard let elementLength = buffer.readInteger(as: Int32.self), elementLength >= 0 else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } guard var elementBuffer = buffer.readSlice(length: numericCast(elementLength)) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } let element = try Element.init(from: &elementBuffer, type: elementType, format: format, context: context) diff --git a/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift index 13308265..1aa264b8 100644 --- a/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift @@ -9,13 +9,13 @@ extension Bool: PostgresDecodable { context: PostgresDecodingContext ) throws { guard type == .bool else { - throw PostgresCastingError.Code.typeMismatch + throw PostgresDecodingError.Code.typeMismatch } switch format { case .binary: guard buffer.readableBytes == 1 else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } switch buffer.readInteger(as: UInt8.self) { @@ -24,11 +24,11 @@ extension Bool: PostgresDecodable { case .some(1): self = true default: - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } case .text: guard buffer.readableBytes == 1 else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } switch buffer.readInteger(as: UInt8.self) { @@ -37,7 +37,7 @@ extension Bool: PostgresDecodable { case .some(UInt8(ascii: "t")): self = true default: - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } } } @@ -47,7 +47,7 @@ extension Bool: PostgresEncodable { public static var psqlType: PostgresDataType { .bool } - + public static var psqlFormat: PostgresFormat { .binary } diff --git a/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift index 4a1848ec..e32ecb10 100644 --- a/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift @@ -5,7 +5,7 @@ extension Date: PostgresEncodable { public static var psqlType: PostgresDataType { .timestamptz } - + public static var psqlFormat: PostgresFormat { .binary } @@ -18,14 +18,14 @@ extension Date: PostgresEncodable { let seconds = self.timeIntervalSince(Self._psqlDateStart) * Double(Self._microsecondsPerSecond) byteBuffer.writeInteger(Int64(seconds)) } - + // MARK: Private Constants @usableFromInline static let _microsecondsPerSecond: Int64 = 1_000_000 @usableFromInline static let _secondsInDay: Int64 = 24 * 60 * 60 - + /// values are stored as seconds before or after midnight 2000-01-01 @usableFromInline static let _psqlDateStart = Date(timeIntervalSince1970: 946_684_800) @@ -42,18 +42,18 @@ extension Date: PostgresDecodable { switch type { case .timestamp, .timestamptz: guard buffer.readableBytes == 8, let microseconds = buffer.readInteger(as: Int64.self) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } let seconds = Double(microseconds) / Double(Self._microsecondsPerSecond) self = Date(timeInterval: seconds, since: Self._psqlDateStart) case .date: guard buffer.readableBytes == 4, let days = buffer.readInteger(as: Int32.self) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } let seconds = Int64(days) * Self._secondsInDay self = Date(timeInterval: Double(seconds), since: Self._psqlDateStart) default: - throw PostgresCastingError.Code.typeMismatch + throw PostgresDecodingError.Code.typeMismatch } } } diff --git a/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift index 3f1c7fa0..4ab96386 100644 --- a/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift @@ -5,7 +5,7 @@ extension Decimal: PostgresEncodable { public static var psqlType: PostgresDataType { .numeric } - + public static var psqlFormat: PostgresFormat { .binary } @@ -34,16 +34,16 @@ extension Decimal: PostgresDecodable { switch (format, type) { case (.binary, .numeric): guard let numeric = PostgresNumeric(buffer: &buffer) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = numeric.decimal case (.text, .numeric): guard let string = buffer.readString(length: buffer.readableBytes), let value = Decimal(string: string) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = value default: - throw PostgresCastingError.Code.typeMismatch + throw PostgresDecodingError.Code.typeMismatch } } } diff --git a/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift index d653e9d8..7943c152 100644 --- a/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift @@ -4,7 +4,7 @@ extension Float: PostgresEncodable { public static var psqlType: PostgresDataType { .float4 } - + public static var psqlFormat: PostgresFormat { .binary } @@ -29,21 +29,21 @@ extension Float: PostgresDecodable { switch (format, type) { case (.binary, .float4): guard buffer.readableBytes == 4, let float = buffer.psqlReadFloat() else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = float case (.binary, .float8): guard buffer.readableBytes == 8, let double = buffer.psqlReadDouble() else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = Float(double) case (.text, .float4), (.text, .float8): guard let string = buffer.readString(length: buffer.readableBytes), let value = Float(string) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = value default: - throw PostgresCastingError.Code.typeMismatch + throw PostgresDecodingError.Code.typeMismatch } } } @@ -54,7 +54,7 @@ extension Double: PostgresEncodable { public static var psqlType: PostgresDataType { .float8 } - + public static var psqlFormat: PostgresFormat { .binary } @@ -79,21 +79,21 @@ extension Double: PostgresDecodable { switch (format, type) { case (.binary, .float4): guard buffer.readableBytes == 4, let float = buffer.psqlReadFloat() else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = Double(float) case (.binary, .float8): guard buffer.readableBytes == 8, let double = buffer.psqlReadDouble() else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = double case (.text, .float4), (.text, .float8): guard let string = buffer.readString(length: buffer.readableBytes), let value = Double(string) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = value default: - throw PostgresCastingError.Code.typeMismatch + throw PostgresDecodingError.Code.typeMismatch } } } diff --git a/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift index 7ea81f31..e4a2492d 100644 --- a/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift @@ -31,12 +31,12 @@ extension UInt8: PostgresDecodable { switch type { case .bpchar, .char: guard buffer.readableBytes == 1, let value = buffer.readInteger(as: UInt8.self) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = value default: - throw PostgresCastingError.Code.typeMismatch + throw PostgresDecodingError.Code.typeMismatch } } } @@ -74,16 +74,16 @@ extension Int16: PostgresDecodable { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = value case (.text, .int2): guard let string = buffer.readString(length: buffer.readableBytes), let value = Int16(string) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = value default: - throw PostgresCastingError.Code.typeMismatch + throw PostgresDecodingError.Code.typeMismatch } } } @@ -96,7 +96,7 @@ extension Int32: PostgresEncodable { public static var psqlType: PostgresDataType { .int4 } - + public static var psqlFormat: PostgresFormat { .binary } @@ -121,21 +121,21 @@ extension Int32: PostgresDecodable { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = Int32(value) case (.binary, .int4): guard buffer.readableBytes == 4, let value = buffer.readInteger(as: Int32.self) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = Int32(value) case (.text, .int2), (.text, .int4): guard let string = buffer.readString(length: buffer.readableBytes), let value = Int32(string) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = value default: - throw PostgresCastingError.Code.typeMismatch + throw PostgresDecodingError.Code.typeMismatch } } } @@ -173,26 +173,26 @@ extension Int64: PostgresDecodable { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = Int64(value) case (.binary, .int4): guard buffer.readableBytes == 4, let value = buffer.readInteger(as: Int32.self) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = Int64(value) case (.binary, .int8): guard buffer.readableBytes == 8, let value = buffer.readInteger(as: Int64.self) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = value case (.text, .int2), (.text, .int4), (.text, .int8): guard let string = buffer.readString(length: buffer.readableBytes), let value = Int64(string) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = value default: - throw PostgresCastingError.Code.typeMismatch + throw PostgresDecodingError.Code.typeMismatch } } } @@ -212,7 +212,7 @@ extension Int: PostgresEncodable { preconditionFailure("Int is expected to be an Int32 or Int64") } } - + public static var psqlFormat: PostgresFormat { .binary } @@ -237,26 +237,26 @@ extension Int: PostgresDecodable { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = Int(value) case (.binary, .int4): guard buffer.readableBytes == 4, let value = buffer.readInteger(as: Int32.self).flatMap({ Int(exactly: $0) }) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = value case (.binary, .int8): guard buffer.readableBytes == 8, let value = buffer.readInteger(as: Int.self).flatMap({ Int(exactly: $0) }) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = value case (.text, .int2), (.text, .int4), (.text, .int8): guard let string = buffer.readString(length: buffer.readableBytes), let value = Int(string) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = value default: - throw PostgresCastingError.Code.typeMismatch + throw PostgresDecodingError.Code.typeMismatch } } } diff --git a/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift index 2e09d03e..d5696bf2 100644 --- a/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift @@ -10,7 +10,7 @@ extension PostgresEncodable where Self: Encodable { public static var psqlType: PostgresDataType { .jsonb } - + public static var psqlFormat: PostgresFormat { .binary } @@ -35,13 +35,13 @@ extension PostgresDecodable where Self: Decodable { switch (format, type) { case (.binary, .jsonb): guard JSONBVersionByte == buffer.readInteger(as: UInt8.self) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = try context.jsonDecoder.decode(Self.self, from: buffer) case (.binary, .json), (.text, .jsonb), (.text, .json): self = try context.jsonDecoder.decode(Self.self, from: buffer) default: - throw PostgresCastingError.Code.typeMismatch + throw PostgresDecodingError.Code.typeMismatch } } } diff --git a/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift index 9a4f6b1d..4c0195e3 100644 --- a/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift @@ -4,7 +4,7 @@ extension PostgresEncodable where Self: RawRepresentable, RawValue: PostgresEnco public static var psqlType: PostgresDataType { RawValue.psqlType } - + public static var psqlFormat: PostgresFormat { RawValue.psqlFormat } @@ -27,7 +27,7 @@ extension PostgresDecodable where Self: RawRepresentable, RawValue: PostgresDeco ) throws { guard let rawValue = try? RawValue(from: &buffer, type: type, format: format, context: context), let selfValue = Self.init(rawValue: rawValue) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = selfValue diff --git a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift index aebfedcd..8efb8155 100644 --- a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift @@ -5,7 +5,7 @@ extension String: PostgresEncodable { public static var psqlType: PostgresDataType { .text } - + public static var psqlFormat: PostgresFormat { .binary } @@ -37,11 +37,11 @@ extension String: PostgresDecodable { self = buffer.readString(length: buffer.readableBytes)! case (_, .uuid): guard let uuid = try? UUID(from: &buffer, type: .uuid, format: format, context: context) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = uuid.uuidString default: - throw PostgresCastingError.Code.typeMismatch + throw PostgresDecodingError.Code.typeMismatch } } } diff --git a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift index f40fff7c..3241ea01 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift @@ -6,7 +6,7 @@ extension UUID: PostgresEncodable { public static var psqlType: PostgresDataType { .uuid } - + public static var psqlFormat: PostgresFormat { .binary } @@ -37,7 +37,7 @@ extension UUID: PostgresDecodable { switch (format, type) { case (.binary, .uuid): guard let uuid = buffer.readUUID() else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = uuid case (.binary, .varchar), @@ -46,15 +46,15 @@ extension UUID: PostgresDecodable { (.text, .text), (.text, .varchar): guard buffer.readableBytes == 36 else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } guard let uuid = buffer.readString(length: 36).flatMap({ UUID(uuidString: $0) }) else { - throw PostgresCastingError.Code.failure + throw PostgresDecodingError.Code.failure } self = uuid default: - throw PostgresCastingError.Code.typeMismatch + throw PostgresDecodingError.Code.typeMismatch } } } @@ -67,13 +67,13 @@ extension ByteBuffer { guard self.readableBytes >= MemoryLayout.size else { return nil } - + let value: UUID = self.getUUID(at: self.readerIndex)! /* must work as we have enough bytes */ // should be MoveReaderIndex self.moveReaderIndex(forwardBy: MemoryLayout.size) return value } - + func getUUID(at index: Int) -> UUID? { var uuid: uuid_t = (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) return self.viewBytes(at: index, length: MemoryLayout.size(ofValue: uuid)).map { bufferBytes in diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index fd402618..2320c822 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -1,7 +1,7 @@ import NIOCore struct PSQLError: Error { - + enum Base { case sslUnsupported case failedToAddSSLHandler(underlying: Error) @@ -19,44 +19,44 @@ struct PSQLError: Error { case connectionClosed case connectionError(underlying: Error) case uncleanShutdown - - case casting(PostgresCastingError) + + case casting(PostgresDecodingError) } - + internal var base: Base - + private init(_ base: Base) { self.base = base } - + static var sslUnsupported: PSQLError { Self.init(.sslUnsupported) } - + static func failedToAddSSLHandler(underlying error: Error) -> PSQLError { Self.init(.failedToAddSSLHandler(underlying: error)) } - + static func server(_ message: PostgresBackendMessage.ErrorResponse) -> PSQLError { Self.init(.server(message)) } - + static func decoding(_ error: PSQLDecodingError) -> PSQLError { Self.init(.decoding(error)) } - + static func unexpectedBackendMessage(_ message: PostgresBackendMessage) -> PSQLError { Self.init(.unexpectedBackendMessage(message)) } - + static func unsupportedAuthMechanism(_ authScheme: PSQLAuthScheme) -> PSQLError { Self.init(.unsupportedAuthMechanism(authScheme)) } - + static var authMechanismRequiresPassword: PSQLError { Self.init(.authMechanismRequiresPassword) } - + static func sasl(underlying: Error) -> PSQLError { Self.init(.saslError(underlyingError: underlying)) } @@ -68,33 +68,31 @@ struct PSQLError: Error { static var queryCancelled: PSQLError { Self.init(.queryCancelled) } - + static var tooManyParameters: PSQLError { Self.init(.tooManyParameters) } - + static var connectionQuiescing: PSQLError { Self.init(.connectionQuiescing) } - + static var connectionClosed: PSQLError { Self.init(.connectionClosed) } - + static func channel(underlying: Error) -> PSQLError { Self.init(.connectionError(underlying: underlying)) } - + static var uncleanShutdown: PSQLError { Self.init(.uncleanShutdown) } } /// An error that may happen when a ``PostgresRow`` or ``PostgresCell`` is decoded to native Swift types. -@usableFromInline -struct PostgresCastingError: Error, Equatable { - @usableFromInline - struct Code: Hashable, Error { +public struct PostgresDecodingError: Error, Equatable { + public struct Code: Hashable, Error { enum Base { case missingData case typeMismatch @@ -107,41 +105,31 @@ struct PostgresCastingError: Error, Equatable { self.base = base } - @usableFromInline - static let missingData = Self.init(.missingData) - @usableFromInline - static let typeMismatch = Self.init(.typeMismatch) - @usableFromInline - static let failure = Self.init(.failure) + public static let missingData = Self.init(.missingData) + public static let typeMismatch = Self.init(.typeMismatch) + public static let failure = Self.init(.failure) } /// The casting error code - let code: Code + public let code: Code /// The cell's column name for which the casting failed - let columnName: String + public let columnName: String /// The cell's column index for which the casting failed - let columnIndex: Int + public let columnIndex: Int /// The swift type the cell should have been casted into - let targetType: Any.Type + public let targetType: Any.Type /// The cell's postgres data type for which the casting failed - let postgresType: PostgresDataType + public let postgresType: PostgresDataType /// The cell's postgres format for which the casting failed - let postgresFormat: PostgresFormat + public let postgresFormat: PostgresFormat /// A copy of the cell data which was attempted to be casted - let postgresData: ByteBuffer? + public let postgresData: ByteBuffer? /// The file the casting/decoding was attempted in - let file: String + public let file: String /// The line the casting/decoding was attempted in - let line: Int - - var description: String { - // This may seem very odd... But we are afraid that users might accidentally send the - // unfiltered errors out to end-users. This may leak security relevant information. For this - // reason we overwrite the error description by default to this generic "Database error" - "Database error" - } + public let line: Int @usableFromInline init( @@ -166,8 +154,7 @@ struct PostgresCastingError: Error, Equatable { self.line = line } - @usableFromInline - static func ==(lhs: PostgresCastingError, rhs: PostgresCastingError) -> Bool { + public static func ==(lhs: PostgresDecodingError, rhs: PostgresDecodingError) -> Bool { return lhs.code == rhs.code && lhs.columnName == rhs.columnName && lhs.columnIndex == rhs.columnIndex @@ -180,6 +167,14 @@ struct PostgresCastingError: Error, Equatable { } } +extension PostgresDecodingError: CustomStringConvertible { + public var description: String { + // This may seem very odd... But we are afraid that users might accidentally send the + // unfiltered errors out to end-users. This may leak security relevant information. For this + // reason we overwrite the error description by default to this generic "Database error" + "Database error" + } +} enum PSQLAuthScheme { case none case kerberosV5 diff --git a/Sources/PostgresNIO/New/PostgresCell.swift b/Sources/PostgresNIO/New/PostgresCell.swift index 5281a798..8d11c78b 100644 --- a/Sources/PostgresNIO/New/PostgresCell.swift +++ b/Sources/PostgresNIO/New/PostgresCell.swift @@ -39,8 +39,8 @@ extension PostgresCell { format: self.format, context: context ) - } catch let code as PostgresCastingError.Code { - throw PostgresCastingError( + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( code: code, columnName: self.columnName, columnIndex: self.columnIndex, diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift index c90594cf..6a40b4bf 100644 --- a/Sources/PostgresNIO/New/PostgresCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -5,10 +5,10 @@ import Foundation public protocol PostgresEncodable { /// identifies the data type that we will encode into `byteBuffer` in `encode` static var psqlType: PostgresDataType { get } - + /// identifies the postgres format that is used to encode the value into `byteBuffer` in `encode` static var psqlFormat: PostgresFormat { get } - + /// Encode the entity into the `byteBuffer` in Postgres binary format, without setting /// the byte count. This method is called from the ``PostgresBindings``. func encode(into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext) throws @@ -61,7 +61,7 @@ extension PostgresDecodable { context: PostgresDecodingContext ) throws -> Self { guard var buffer = byteBuffer else { - throw PostgresCastingError.Code.missingData + throw PostgresDecodingError.Code.missingData } return try self.init(from: &buffer, type: type, format: format, context: context) } @@ -84,7 +84,7 @@ extension PostgresEncodable { // The value of the parameter, in the format indicated by the associated format // code. n is the above length. try self.encode(into: &buffer, context: context) - + // overwrite the empty length, with the real value buffer.setInteger(numericCast(buffer.writerIndex - startIndex), at: lengthIndex, as: Int32.self) } diff --git a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift index 6ca0e54b..d5386b08 100644 --- a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift +++ b/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift @@ -16,8 +16,8 @@ extension PostgresRow { let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0) - } catch let code as PostgresCastingError.Code { - throw PostgresCastingError( + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( code: code, columnName: column.name, columnIndex: columnIndex, @@ -52,8 +52,8 @@ extension PostgresRow { let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1) - } catch let code as PostgresCastingError.Code { - throw PostgresCastingError( + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( code: code, columnName: column.name, columnIndex: columnIndex, @@ -94,8 +94,8 @@ extension PostgresRow { let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2) - } catch let code as PostgresCastingError.Code { - throw PostgresCastingError( + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( code: code, columnName: column.name, columnIndex: columnIndex, @@ -142,8 +142,8 @@ extension PostgresRow { let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3) - } catch let code as PostgresCastingError.Code { - throw PostgresCastingError( + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( code: code, columnName: column.name, columnIndex: columnIndex, @@ -196,8 +196,8 @@ extension PostgresRow { let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4) - } catch let code as PostgresCastingError.Code { - throw PostgresCastingError( + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( code: code, columnName: column.name, columnIndex: columnIndex, @@ -256,8 +256,8 @@ extension PostgresRow { let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5) - } catch let code as PostgresCastingError.Code { - throw PostgresCastingError( + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( code: code, columnName: column.name, columnIndex: columnIndex, @@ -322,8 +322,8 @@ extension PostgresRow { let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5, r6) - } catch let code as PostgresCastingError.Code { - throw PostgresCastingError( + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( code: code, columnName: column.name, columnIndex: columnIndex, @@ -394,8 +394,8 @@ extension PostgresRow { let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5, r6, r7) - } catch let code as PostgresCastingError.Code { - throw PostgresCastingError( + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( code: code, columnName: column.name, columnIndex: columnIndex, @@ -472,8 +472,8 @@ extension PostgresRow { let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5, r6, r7, r8) - } catch let code as PostgresCastingError.Code { - throw PostgresCastingError( + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( code: code, columnName: column.name, columnIndex: columnIndex, @@ -556,8 +556,8 @@ extension PostgresRow { let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9) - } catch let code as PostgresCastingError.Code { - throw PostgresCastingError( + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( code: code, columnName: column.name, columnIndex: columnIndex, @@ -646,8 +646,8 @@ extension PostgresRow { let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10) - } catch let code as PostgresCastingError.Code { - throw PostgresCastingError( + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( code: code, columnName: column.name, columnIndex: columnIndex, @@ -742,8 +742,8 @@ extension PostgresRow { let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11) - } catch let code as PostgresCastingError.Code { - throw PostgresCastingError( + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( code: code, columnName: column.name, columnIndex: columnIndex, @@ -844,8 +844,8 @@ extension PostgresRow { let r12 = try T12._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12) - } catch let code as PostgresCastingError.Code { - throw PostgresCastingError( + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( code: code, columnName: column.name, columnIndex: columnIndex, @@ -952,8 +952,8 @@ extension PostgresRow { let r13 = try T13._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13) - } catch let code as PostgresCastingError.Code { - throw PostgresCastingError( + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( code: code, columnName: column.name, columnIndex: columnIndex, @@ -1066,8 +1066,8 @@ extension PostgresRow { let r14 = try T14._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14) - } catch let code as PostgresCastingError.Code { - throw PostgresCastingError( + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( code: code, columnName: column.name, columnIndex: columnIndex, diff --git a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift index 3798dab0..7b112b08 100644 --- a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift @@ -3,7 +3,7 @@ import NIOCore @testable import PostgresNIO class Array_PSQLCodableTests: XCTestCase { - + func testArrayTypes() { XCTAssertEqual(Bool.psqlArrayType, .boolArray) @@ -29,7 +29,7 @@ class Array_PSQLCodableTests: XCTestCase { XCTAssertEqual(Int64.psqlArrayType, .int8Array) XCTAssertEqual(Int64.psqlType, .int8) XCTAssertEqual([Int64].psqlType, .int8Array) - + #if (arch(i386) || arch(arm)) XCTAssertEqual(Int.psqlArrayType, .int4Array) XCTAssertEqual(Int.psqlType, .int4) @@ -56,61 +56,61 @@ class Array_PSQLCodableTests: XCTestCase { XCTAssertEqual(UUID.psqlType, .uuid) XCTAssertEqual([UUID].psqlType, .uuidArray) } - + func testStringArrayRoundTrip() { let values = ["foo", "bar", "hello", "world"] - + var buffer = ByteBuffer() XCTAssertNoThrow(try values.encode(into: &buffer, context: .default)) - + var result: [String]? XCTAssertNoThrow(result = try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) XCTAssertEqual(values, result) } - + func testEmptyStringArrayRoundTrip() { let values: [String] = [] - + var buffer = ByteBuffer() XCTAssertNoThrow(try values.encode(into: &buffer, context: .default)) - + var result: [String]? XCTAssertNoThrow(result = try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) XCTAssertEqual(values, result) } - + func testDecodeFailureIsNotEmptyOutOfScope() { var buffer = ByteBuffer() buffer.writeInteger(Int32(2)) // invalid value buffer.writeInteger(Int32(0)) buffer.writeInteger(String.psqlType.rawValue) - + XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } - + func testDecodeFailureSecondValueIsUnexpected() { var buffer = ByteBuffer() buffer.writeInteger(Int32(0)) // is empty buffer.writeInteger(Int32(1)) // invalid value, must always be 0 buffer.writeInteger(String.psqlType.rawValue) - + XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } - + func testDecodeFailureTriesDecodeInt8() { let value: Int64 = 1 << 32 var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } - + func testDecodeFailureInvalidNumberOfArrayElements() { var buffer = ByteBuffer() buffer.writeInteger(Int32(1)) // invalid value @@ -120,10 +120,10 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int32(1)) // dimensions... must be one XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } - + func testDecodeFailureInvalidNumberOfDimensions() { var buffer = ByteBuffer() buffer.writeInteger(Int32(1)) // invalid value @@ -131,12 +131,12 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(String.psqlType.rawValue) buffer.writeInteger(Int32(1)) // expected element count buffer.writeInteger(Int32(2)) // dimensions... must be one - + XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } - + func testDecodeUnexpectedEnd() { var unexpectedEndInElementLengthBuffer = ByteBuffer() unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // invalid value @@ -145,11 +145,11 @@ class Array_PSQLCodableTests: XCTestCase { unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // expected element count unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // dimensions unexpectedEndInElementLengthBuffer.writeInteger(Int16(1)) // length of element, must be Int32 - + XCTAssertThrowsError(try [String](from: &unexpectedEndInElementLengthBuffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } - + var unexpectedEndInElementBuffer = ByteBuffer() unexpectedEndInElementBuffer.writeInteger(Int32(1)) // invalid value unexpectedEndInElementBuffer.writeInteger(Int32(0)) @@ -158,9 +158,9 @@ class Array_PSQLCodableTests: XCTestCase { unexpectedEndInElementBuffer.writeInteger(Int32(1)) // dimensions unexpectedEndInElementBuffer.writeInteger(Int32(12)) // length of element, must be Int32 unexpectedEndInElementBuffer.writeString("Hello World") // only 11 bytes, 12 needed! - + XCTAssertThrowsError(try [String](from: &unexpectedEndInElementBuffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } } diff --git a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift index 9526fcd6..e6e43f0b 100644 --- a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift @@ -3,27 +3,27 @@ import NIOCore @testable import PostgresNIO class Bool_PSQLCodableTests: XCTestCase { - + // MARK: - Binary - + func testBinaryTrueRoundTrip() { let value = true - + var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) XCTAssertEqual(Bool.psqlType, .bool) XCTAssertEqual(Bool.psqlFormat, .binary) XCTAssertEqual(buffer.readableBytes, 1) XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 1) - + var result: Bool? XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) XCTAssertEqual(value, result) } - + func testBinaryFalseRoundTrip() { let value = false - + var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) XCTAssertEqual(Bool.psqlType, .bool) @@ -35,30 +35,30 @@ class Bool_PSQLCodableTests: XCTestCase { XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) XCTAssertEqual(value, result) } - + func testBinaryDecodeBoolInvalidLength() { var buffer = ByteBuffer() buffer.writeInteger(Int64(1)) XCTAssertThrowsError(try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } - + func testBinaryDecodeBoolInvalidValue() { var buffer = ByteBuffer() buffer.writeInteger(UInt8(13)) XCTAssertThrowsError(try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } // MARK: - Text - + func testTextTrueDecode() { let value = true - + var buffer = ByteBuffer() buffer.writeInteger(UInt8(ascii: "t")) @@ -66,10 +66,10 @@ class Bool_PSQLCodableTests: XCTestCase { XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .text, context: .default)) XCTAssertEqual(value, result) } - + func testTextFalseDecode() { let value = false - + var buffer = ByteBuffer() buffer.writeInteger(UInt8(ascii: "f")) @@ -77,13 +77,13 @@ class Bool_PSQLCodableTests: XCTestCase { XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .text, context: .default)) XCTAssertEqual(value, result) } - + func testTextDecodeBoolInvalidValue() { var buffer = ByteBuffer() buffer.writeInteger(UInt8(13)) XCTAssertThrowsError(try Bool(from: &buffer, type: .bool, format: .text, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } } diff --git a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift index 9fe0e67b..b08c2de2 100644 --- a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift @@ -3,10 +3,10 @@ import NIOCore @testable import PostgresNIO class Date_PSQLCodableTests: XCTestCase { - + func testNowRoundTrip() { let value = Date() - + var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) XCTAssertEqual(Date.psqlType, .timestamptz) @@ -16,7 +16,7 @@ class Date_PSQLCodableTests: XCTestCase { XCTAssertNoThrow(result = try Date(from: &buffer, type: .timestamptz, format: .binary, context: .default)) XCTAssertEqual(value, result) } - + func testDecodeRandomDate() { var buffer = ByteBuffer() buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) @@ -25,25 +25,25 @@ class Date_PSQLCodableTests: XCTestCase { XCTAssertNoThrow(result = try Date(from: &buffer, type: .timestamptz, format: .binary, context: .default)) XCTAssertNotNil(result) } - + func testDecodeFailureInvalidLength() { var buffer = ByteBuffer() buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) XCTAssertThrowsError(try Date(from: &buffer, type: .timestamptz, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } - + func testDecodeDate() { var firstDateBuffer = ByteBuffer() firstDateBuffer.writeInteger(Int32.min) - + var firstDate: Date? XCTAssertNoThrow(firstDate = try Date(from: &firstDateBuffer, type: .date, format: .binary, context: .default)) XCTAssertNotNil(firstDate) - + var lastDateBuffer = ByteBuffer() lastDateBuffer.writeInteger(Int32.max) @@ -51,39 +51,39 @@ class Date_PSQLCodableTests: XCTestCase { XCTAssertNoThrow(lastDate = try Date(from: &lastDateBuffer, type: .date, format: .binary, context: .default)) XCTAssertNotNil(lastDate) } - + func testDecodeDateFromTimestamp() { var firstDateBuffer = ByteBuffer() firstDateBuffer.writeInteger(Int32.min) - + var firstDate: Date? XCTAssertNoThrow(firstDate = try Date(from: &firstDateBuffer, type: .date, format: .binary, context: .default)) XCTAssertNotNil(firstDate) - + var lastDateBuffer = ByteBuffer() lastDateBuffer.writeInteger(Int32.max) - + var lastDate: Date? XCTAssertNoThrow(lastDate = try Date(from: &lastDateBuffer, type: .date, format: .binary, context: .default)) XCTAssertNotNil(lastDate) } - + func testDecodeDateFailsWithToMuchData() { var buffer = ByteBuffer() buffer.writeInteger(Int64(0)) XCTAssertThrowsError(try Date(from: &buffer, type: .date, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } - + func testDecodeDateFailsWithWrongDataType() { var buffer = ByteBuffer() buffer.writeInteger(Int64(0)) XCTAssertThrowsError(try Date(from: &buffer, type: .int8, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .typeMismatch) } } - + } diff --git a/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift index cfb7f7e3..f9d57397 100644 --- a/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Decimal+PSQLCodableTests.swift @@ -3,10 +3,10 @@ import NIOCore @testable import PostgresNIO class Decimal_PSQLCodableTests: XCTestCase { - + func testRoundTrip() { let values: [Decimal] = [1.1, .pi, -5e-12] - + for value in values { var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) @@ -17,14 +17,14 @@ class Decimal_PSQLCodableTests: XCTestCase { XCTAssertEqual(value, result) } } - + func testDecodeFailureInvalidType() { var buffer = ByteBuffer() buffer.writeInteger(Int64(0)) - + XCTAssertThrowsError(try Decimal(from: &buffer, type: .int8, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .typeMismatch) } } - + } diff --git a/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift index 9fd1bb9e..728b87b7 100644 --- a/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Float+PSQLCodableTests.swift @@ -3,10 +3,10 @@ import NIOCore @testable import PostgresNIO class Float_PSQLCodableTests: XCTestCase { - + func testRoundTripDoubles() { let values: [Double] = [1.1, .pi, -5e-12] - + for value in values { var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) @@ -18,10 +18,10 @@ class Float_PSQLCodableTests: XCTestCase { XCTAssertEqual(value, result) } } - + func testRoundTripFloat() { let values: [Float] = [1.1, .pi, -5e-12] - + for value in values { var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) @@ -33,10 +33,10 @@ class Float_PSQLCodableTests: XCTestCase { XCTAssertEqual(value, result) } } - + func testRoundTripDoubleNaN() { let value: Double = .nan - + var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) XCTAssertEqual(Double.psqlType, .float8) @@ -46,10 +46,10 @@ class Float_PSQLCodableTests: XCTestCase { XCTAssertNoThrow(result = try Double(from: &buffer, type: .float8, format: .binary, context: .default)) XCTAssertEqual(result?.isNaN, true) } - + func testRoundTripDoubleInfinity() { let value: Double = .infinity - + var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) XCTAssertEqual(Double.psqlType, .float8) @@ -59,10 +59,10 @@ class Float_PSQLCodableTests: XCTestCase { XCTAssertNoThrow(result = try Double(from: &buffer, type: .float8, format: .binary, context: .default)) XCTAssertEqual(result?.isInfinite, true) } - + func testRoundTripFromFloatToDouble() { let values: [Float] = [1.1, .pi, -5e-12] - + for value in values { var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) @@ -74,10 +74,10 @@ class Float_PSQLCodableTests: XCTestCase { XCTAssertEqual(result, Double(value)) } } - + func testRoundTripFromDoubleToFloat() { let values: [Double] = [1.1, .pi, -5e-12] - + for value in values { var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) @@ -89,7 +89,7 @@ class Float_PSQLCodableTests: XCTestCase { XCTAssertEqual(result, Float(value)) } } - + func testDecodeFailureInvalidLength() { var eightByteBuffer = ByteBuffer() eightByteBuffer.writeInteger(Int64(0)) @@ -98,37 +98,37 @@ class Float_PSQLCodableTests: XCTestCase { var toLongBuffer1 = eightByteBuffer XCTAssertThrowsError(try Double(from: &toLongBuffer1, type: .float4, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } var toLongBuffer2 = eightByteBuffer XCTAssertThrowsError(try Float(from: &toLongBuffer2, type: .float4, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } var toShortBuffer1 = fourByteBuffer XCTAssertThrowsError(try Double(from: &toShortBuffer1, type: .float8, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } var toShortBuffer2 = fourByteBuffer XCTAssertThrowsError(try Float(from: &toShortBuffer2, type: .float8, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } - + func testDecodeFailureInvalidType() { var buffer = ByteBuffer() buffer.writeInteger(Int64(0)) var copy1 = buffer XCTAssertThrowsError(try Double(from: ©1, type: .int8, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .typeMismatch) } var copy2 = buffer XCTAssertThrowsError(try Float(from: ©2, type: .int8, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .typeMismatch) } } } diff --git a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift index dbaa43ee..858b6ede 100644 --- a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift @@ -3,21 +3,21 @@ import NIOCore @testable import PostgresNIO class JSON_PSQLCodableTests: XCTestCase { - + struct Hello: Equatable, Codable, PostgresCodable { let hello: String - + init(name: String) { self.hello = name } } - + func testRoundTrip() { var buffer = ByteBuffer() let hello = Hello(name: "world") XCTAssertNoThrow(try hello.encode(into: &buffer, context: .default)) XCTAssertEqual(Hello.psqlType, .jsonb) - + // verify jsonb prefix byte XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 1) @@ -25,7 +25,7 @@ class JSON_PSQLCodableTests: XCTestCase { XCTAssertNoThrow(result = try Hello(from: &buffer, type: .jsonb, format: .binary, context: .default)) XCTAssertEqual(result, hello) } - + func testDecodeFromJSON() { var buffer = ByteBuffer() buffer.writeString(#"{"hello":"world"}"#) @@ -34,14 +34,14 @@ class JSON_PSQLCodableTests: XCTestCase { XCTAssertNoThrow(result = try Hello(from: &buffer, type: .json, format: .binary, context: .default)) XCTAssertEqual(result, Hello(name: "world")) } - + func testDecodeFromJSONAsText() { let combinations : [(PostgresFormat, PostgresDataType)] = [ (.text, .json), (.text, .jsonb), ] var buffer = ByteBuffer() buffer.writeString(#"{"hello":"world"}"#) - + for (format, dataType) in combinations { var loopBuffer = buffer var result: Hello? @@ -49,29 +49,29 @@ class JSON_PSQLCodableTests: XCTestCase { XCTAssertEqual(result, Hello(name: "world")) } } - + func testDecodeFromJSONBWithoutVersionPrefixByte() { var buffer = ByteBuffer() buffer.writeString(#"{"hello":"world"}"#) XCTAssertThrowsError(try Hello(from: &buffer, type: .jsonb, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } - + func testDecodeFromJSONBWithWrongDataType() { var buffer = ByteBuffer() buffer.writeString(#"{"hello":"world"}"#) XCTAssertThrowsError(try Hello(from: &buffer, type: .text, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .typeMismatch) } } - + func testCustomEncoderIsUsed() { class TestEncoder: PostgresJSONEncoder { var encodeHits = 0 - + func encode(_ value: T, into buffer: inout ByteBuffer) throws where T : Encodable { self.encodeHits += 1 } @@ -80,7 +80,7 @@ class JSON_PSQLCodableTests: XCTestCase { preconditionFailure() } } - + let hello = Hello(name: "world") let encoder = TestEncoder() var buffer = ByteBuffer() diff --git a/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift index a0808daf..0868a4ee 100644 --- a/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/RawRepresentable+PSQLCodableTests.swift @@ -3,16 +3,16 @@ import NIOCore @testable import PostgresNIO class RawRepresentable_PSQLCodableTests: XCTestCase { - + enum MyRawRepresentable: Int16, PostgresCodable { case testing = 1 case staging = 2 case production = 3 } - + func testRoundTrip() { let values: [MyRawRepresentable] = [.testing, .staging, .production] - + for value in values { var buffer = ByteBuffer() XCTAssertNoThrow(try value.encode(into: &buffer, context: .default)) @@ -24,23 +24,23 @@ class RawRepresentable_PSQLCodableTests: XCTestCase { XCTAssertEqual(value, result) } } - + func testDecodeInvalidRawTypeValue() { var buffer = ByteBuffer() buffer.writeInteger(Int16(4)) // out of bounds XCTAssertThrowsError(try MyRawRepresentable(from: &buffer, type: Int16.psqlType, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } - + func testDecodeInvalidUnderlyingTypeValue() { var buffer = ByteBuffer() buffer.writeInteger(Int32(1)) // out of bounds XCTAssertThrowsError(try MyRawRepresentable(from: &buffer, type: Int32.psqlType, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } - + } diff --git a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift index 42edbda5..614749c1 100644 --- a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift @@ -3,26 +3,26 @@ import NIOCore @testable import PostgresNIO class String_PSQLCodableTests: XCTestCase { - + func testEncode() { let value = "Hello World" var buffer = ByteBuffer() - + value.encode(into: &buffer, context: .default) - + XCTAssertEqual(String.psqlType, .text) XCTAssertEqual(buffer.readString(length: buffer.readableBytes), value) } - + func testDecodeStringFromTextVarchar() { let expected = "Hello World" var buffer = ByteBuffer() buffer.writeString(expected) - + let dataTypes: [PostgresDataType] = [ .text, .varchar, .name ] - + for dataType in dataTypes { var loopBuffer = buffer var result: String? @@ -30,39 +30,38 @@ class String_PSQLCodableTests: XCTestCase { XCTAssertEqual(result, expected) } } - + func testDecodeFailureFromInvalidType() { let buffer = ByteBuffer() let dataTypes: [PostgresDataType] = [.bool, .float4Array, .float8Array, .bpchar] - + for dataType in dataTypes { var loopBuffer = buffer XCTAssertThrowsError(try String(from: &loopBuffer, type: dataType, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .typeMismatch) } } } - + func testDecodeFromUUID() { let uuid = UUID() var buffer = ByteBuffer() uuid.encode(into: &buffer, context: .default) - + var decoded: String? XCTAssertNoThrow(decoded = try String(from: &buffer, type: .uuid, format: .binary, context: .default)) XCTAssertEqual(decoded, uuid.uuidString) } - + func testDecodeFailureFromInvalidUUID() { let uuid = UUID() var buffer = ByteBuffer() uuid.encode(into: &buffer, context: .default) // this makes only 15 bytes readable. this should lead to an error buffer.moveReaderIndex(forwardBy: 1) - + XCTAssertThrowsError(try String(from: &buffer, type: .uuid, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } } - diff --git a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift index 0693f7f4..2ca2d1d8 100644 --- a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift @@ -3,19 +3,19 @@ import NIOCore @testable import PostgresNIO class UUID_PSQLCodableTests: XCTestCase { - + func testRoundTrip() { for _ in 0..<100 { let uuid = UUID() var buffer = ByteBuffer() - + uuid.encode(into: &buffer, context: .default) - + XCTAssertEqual(UUID.psqlType, .uuid) XCTAssertEqual(UUID.psqlFormat, .binary) XCTAssertEqual(buffer.readableBytes, 16) var byteIterator = buffer.readableBytesView.makeIterator() - + XCTAssertEqual(byteIterator.next(), uuid.uuid.0) XCTAssertEqual(byteIterator.next(), uuid.uuid.1) XCTAssertEqual(byteIterator.next(), uuid.uuid.2) @@ -32,13 +32,13 @@ class UUID_PSQLCodableTests: XCTestCase { XCTAssertEqual(byteIterator.next(), uuid.uuid.13) XCTAssertEqual(byteIterator.next(), uuid.uuid.14) XCTAssertEqual(byteIterator.next(), uuid.uuid.15) - + var decoded: UUID? XCTAssertNoThrow(decoded = try UUID(from: &buffer, type: .uuid, format: .binary, context: .default)) XCTAssertEqual(decoded, uuid) } } - + func testDecodeFromString() { let options: [(PostgresFormat, PostgresDataType)] = [ (.binary, .text), @@ -47,24 +47,24 @@ class UUID_PSQLCodableTests: XCTestCase { (.text, .text), (.text, .varchar), ] - + for _ in 0..<100 { // use uppercase let uuid = UUID() var lowercaseBuffer = ByteBuffer() lowercaseBuffer.writeString(uuid.uuidString.lowercased()) - + for (format, dataType) in options { var loopBuffer = lowercaseBuffer var decoded: UUID? XCTAssertNoThrow(decoded = try UUID(from: &loopBuffer, type: dataType, format: format, context: .default)) XCTAssertEqual(decoded, uuid) } - + // use lowercase var uppercaseBuffer = ByteBuffer() uppercaseBuffer.writeString(uuid.uuidString) - + for (format, dataType) in options { var loopBuffer = uppercaseBuffer var decoded: UUID? @@ -73,48 +73,48 @@ class UUID_PSQLCodableTests: XCTestCase { } } } - + func testDecodeFailureFromBytes() { let uuid = UUID() var buffer = ByteBuffer() - + uuid.encode(into: &buffer, context: .default) // this makes only 15 bytes readable. this should lead to an error buffer.moveReaderIndex(forwardBy: 1) - + XCTAssertThrowsError(try UUID(from: &buffer, type: .uuid, format: .binary, context: .default)) { error in - XCTAssertEqual(error as? PostgresCastingError.Code, .failure) + XCTAssertEqual(error as? PostgresDecodingError.Code, .failure) } } - + func testDecodeFailureFromString() { let uuid = UUID() var buffer = ByteBuffer() buffer.writeString(uuid.uuidString) // this makes only 15 bytes readable. this should lead to an error buffer.moveReaderIndex(forwardBy: 1) - + let dataTypes: [PostgresDataType] = [.varchar, .text] - + for dataType in dataTypes { var loopBuffer = buffer XCTAssertThrowsError(try UUID(from: &loopBuffer, type: dataType, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .failure) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } } - + func testDecodeFailureFromInvalidPostgresType() { let uuid = UUID() var buffer = ByteBuffer() buffer.writeString(uuid.uuidString) - + let dataTypes: [PostgresDataType] = [.bool, .int8, .int2, .int4Array] - + for dataType in dataTypes { - var copy = buffer + var copy = buffer XCTAssertThrowsError(try UUID(from: ©, type: dataType, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresCastingError.Code, .typeMismatch) + XCTAssertEqual($0 as? PostgresDecodingError.Code, .typeMismatch) } } } diff --git a/Tests/PostgresNIOTests/New/PostgresCellTests.swift b/Tests/PostgresNIOTests/New/PostgresCellTests.swift index e7d1cb30..df7cbfd9 100644 --- a/Tests/PostgresNIOTests/New/PostgresCellTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresCellTests.swift @@ -40,7 +40,7 @@ final class PostgresCellTests: XCTestCase { ) XCTAssertThrowsError(try cell.decode(Int?.self, context: .default)) { - guard let error = $0 as? PostgresCastingError else { + guard let error = $0 as? PostgresDecodingError else { return XCTFail("Unexpected error") } diff --git a/Tests/PostgresNIOTests/New/PostgresCodableTests.swift b/Tests/PostgresNIOTests/New/PostgresCodableTests.swift index 0a3096e8..ef76e22a 100644 --- a/Tests/PostgresNIOTests/New/PostgresCodableTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresCodableTests.swift @@ -53,12 +53,12 @@ final class PostgresCodableTests: XCTestCase { ) XCTAssertThrowsError(try row.decode(String.self, context: .default)) { - XCTAssertEqual(($0 as? PostgresCastingError)?.line, #line - 1) - XCTAssertEqual(($0 as? PostgresCastingError)?.file, #file) + XCTAssertEqual(($0 as? PostgresDecodingError)?.line, #line - 1) + XCTAssertEqual(($0 as? PostgresDecodingError)?.file, #file) - XCTAssertEqual(($0 as? PostgresCastingError)?.code, .missingData) - XCTAssert(($0 as? PostgresCastingError)?.targetType == String.self) - XCTAssertEqual(($0 as? PostgresCastingError)?.postgresType, .text) + XCTAssertEqual(($0 as? PostgresDecodingError)?.code, .missingData) + XCTAssert(($0 as? PostgresDecodingError)?.targetType == String.self) + XCTAssertEqual(($0 as? PostgresDecodingError)?.postgresType, .text) } } } diff --git a/Tests/PostgresNIOTests/New/PostgresErrorTests.swift b/Tests/PostgresNIOTests/New/PostgresErrorTests.swift index 79f673c1..a3f44980 100644 --- a/Tests/PostgresNIOTests/New/PostgresErrorTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresErrorTests.swift @@ -1,9 +1,9 @@ @testable import PostgresNIO import XCTest -final class PostgresCastingErrorTests: XCTestCase { - func testPostgresCastingErrorEquality() { - let error1 = PostgresCastingError( +final class PostgresDecodingErrorTests: XCTestCase { + func testPostgresDecodingErrorEquality() { + let error1 = PostgresDecodingError( code: .typeMismatch, columnName: "column", columnIndex: 0, @@ -15,7 +15,7 @@ final class PostgresCastingErrorTests: XCTestCase { line: 123 ) - let error2 = PostgresCastingError( + let error2 = PostgresDecodingError( code: .typeMismatch, columnName: "column", columnIndex: 0, @@ -32,8 +32,8 @@ final class PostgresCastingErrorTests: XCTestCase { XCTAssertEqual(error1, error3) } - func testPostgresCastingErrorDescription() { - let error = PostgresCastingError( + func testPostgresDecodingErrorDescription() { + let error = PostgresDecodingError( code: .typeMismatch, columnName: "column", columnIndex: 0, @@ -45,6 +45,6 @@ final class PostgresCastingErrorTests: XCTestCase { line: 123 ) - XCTAssertNotEqual("\(error)", "Database error") + XCTAssertEqual("\(error)", "Database error") } } diff --git a/dev/generate-postgresrow-multi-decode.sh b/dev/generate-postgresrow-multi-decode.sh index 2fb98c24..cebd9449 100755 --- a/dev/generate-postgresrow-multi-decode.sh +++ b/dev/generate-postgresrow-multi-decode.sh @@ -65,8 +65,8 @@ function gen() { echo -n ", r$(($n))" done echo ")" - echo " } catch let code as PostgresCastingError.Code {" - echo " throw PostgresCastingError(" + echo " } catch let code as PostgresDecodingError.Code {" + echo " throw PostgresDecodingError(" echo " code: code," echo " columnName: column.name," echo " columnIndex: columnIndex," From 5358acb5447ee898b9a84c7a48dccf9af726685c Mon Sep 17 00:00:00 2001 From: BennyDB <74614235+BennyDeBock@users.noreply.github.com> Date: Tue, 3 May 2022 15:29:48 +0200 Subject: [PATCH 105/292] Add project board workflow (#291) --- .github/workflows/projectboard.yml | 83 ++++++++---------------------- 1 file changed, 21 insertions(+), 62 deletions(-) diff --git a/.github/workflows/projectboard.yml b/.github/workflows/projectboard.yml index e4ff9c69..b857f6ee 100644 --- a/.github/workflows/projectboard.yml +++ b/.github/workflows/projectboard.yml @@ -1,72 +1,31 @@ -name: first-issues-to-beginner-issues-project +name: issue-to-project-board-workflow on: # Trigger when an issue gets labeled or deleted issues: types: [reopened, closed, labeled, unlabeled, assigned, unassigned] jobs: - manage_project_issues: - strategy: - fail-fast: false - matrix: - project: - - 'Beginner Issues' + setup_matrix_input: runs-on: ubuntu-latest - if: contains(github.event.issue.labels.*.name, 'good first issue') - steps: - # When an issue that is open is labeled, unassigned or reopened without a assigned member - # create or move the card to "To do" - - name: Create or Update Project Card - if: | - github.event.action == 'labeled' || - github.event.action == 'reopened' || - github.event.action == 'unassigned' - uses: alex-page/github-project-automation-plus@v0.8.1 - with: - project: ${{ matrix.project }} - column: 'To do' - repo-token: ${{ secrets.PROJECT_BOARD_AUTOMATION_PAT }} - - # When an issue that is open is assigned and has an assigned member - # create or move the card to "In progress" - - name: Assign Project Card - if: | - github.event.action == 'assigned' - uses: alex-page/github-project-automation-plus@v0.8.1 - with: - project: ${{ matrix.project }} - column: 'In progress' - repo-token: ${{ secrets.PROJECT_BOARD_AUTOMATION_PAT }} - # When an issue is closed with the good first issue tag - # Create or move the card to "Done" - - name: Close Project Card - if: | - github.event.action == 'closed' - uses: asmfnk/my-github-project-automation@v0.5.0 - with: - project: ${{ matrix.project }} - column: 'Done' - repo-token: ${{ secrets.PROJECT_BOARD_AUTOMATION_PAT }} - - remove_project_issues: - strategy: - fail-fast: false - matrix: - project: - - 'Beginner Issues' - runs-on: ubuntu-latest - if: ${{ !contains(github.event.issue.labels.*.name, 'good first issue') }} steps: - # When an issue has the tag 'good first issue' removed - # Remove the card from the board - - name: Remove Project Card - if: | - github.event.action == 'unlabeled' - uses: alex-page/github-project-automation-plus@v0.8.1 - with: - project: ${{ matrix.project }} - column: 'To do' - repo-token: ${{ secrets.PROJECT_BOARD_AUTOMATION_PAT }} - action: delete + - id: set-matrix + run: | + output=$(curl ${{ github.event.issue.url }}/labels | jq '.[] | .name') + echo '======================' + echo 'Process incoming data' + echo '======================' + json=$(echo $output | sed 's/"\s"/","/g') + echo $json + echo "::set-output name=matrix::$(echo $json)" + outputs: + issueTags: ${{ steps.set-matrix.outputs.matrix }} + + Manage_project_issues: + needs: setup_matrix_input + uses: vapor/ci/.github/workflows/issues-to-project-board.yml@main + with: + labelsJson: ${{ needs.setup_matrix_input.outputs.issueTags }} + secrets: + PROJECT_BOARD_AUTOMATION_PAT: "${{ secrets.PROJECT_BOARD_AUTOMATION_PAT }}" From 2ddc2e1d5ae8e1262d7f7891c4c9fd1c5ee17ddd Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 16 May 2022 08:14:40 +0200 Subject: [PATCH 106/292] Update README.md (#293) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index dc7e14ed..8fa322ec 100644 --- a/README.md +++ b/README.md @@ -114,9 +114,9 @@ let config = PostgresConnection.Configuration( ) let connection = try await PostgresConnection.connect( - on eventLoop: eventLoopGroup.next(), + on: eventLoopGroup.next(), configuration: config, - id connectionID: 1, + id: 1, logger: logger ) From 2825829d81ce98b11e30306e97af0d412e23a5da Mon Sep 17 00:00:00 2001 From: Thomas Rasch Date: Fri, 3 Jun 2022 16:20:49 +0200 Subject: [PATCH 107/292] Expose connectTimeout as a configuration option (#276) fixes: #273 Expose connectTimeout as a configuration option Co-authored-by: Fabian Fett --- .../PostgresNIO/Connection/PostgresConnection.swift | 13 +++++++++++-- .../New/PostgresChannelHandlerTests.swift | 4 +++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 08b5149e..70f61730 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -78,9 +78,15 @@ public final class PostgresConnection { /// - Default: 5432 public var port: Int + /// Specifies a timeout to apply to a connection attempt. + /// + /// - Default: 10 seconds + public var connectTimeout: TimeAmount + public init(host: String, port: Int = 5432) { self.host = host self.port = port + self.connectTimeout = .seconds(10) } } @@ -281,12 +287,12 @@ public final class PostgresConnection { ) -> NIOClientTCPBootstrapProtocol { #if canImport(Network) if let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) { - return tsBootstrap + return tsBootstrap.connectTimeout(configuration.connectTimeout) } #endif if let nioBootstrap = ClientBootstrap(validatingGroup: eventLoop) { - return nioBootstrap + return nioBootstrap.connectTimeout(configuration.connectTimeout) } fatalError("No matching bootstrap found") @@ -393,6 +399,7 @@ extension PostgresConnection { return tlsFuture.flatMap { tls in let configuration = PostgresConnection.InternalConfiguration( connection: .resolved(address: socketAddress, serverName: serverHostname), + connectTimeout: .seconds(10), authentication: nil, tls: tls ) @@ -752,6 +759,7 @@ extension PostgresConnection { } var connection: Connection + var connectTimeout: TimeAmount var authentication: Configuration.Authentication? @@ -763,6 +771,7 @@ extension PostgresConnection.InternalConfiguration { init(_ config: PostgresConnection.Configuration) { self.authentication = config.authentication self.connection = .unresolved(host: config.connection.host, port: config.connection.port) + self.connectTimeout = config.connection.connectTimeout self.tls = config.tls } } diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index d3c2b10f..e2e73b46 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -173,7 +173,8 @@ class PostgresChannelHandlerTests: XCTestCase { username: String = "test", database: String = "postgres", password: String = "password", - tls: PostgresConnection.Configuration.TLS = .disable + tls: PostgresConnection.Configuration.TLS = .disable, + connectTimeout: TimeAmount = .seconds(10) ) -> PostgresConnection.InternalConfiguration { let authentication = PostgresConnection.Configuration.Authentication( username: username, @@ -183,6 +184,7 @@ class PostgresChannelHandlerTests: XCTestCase { return PostgresConnection.InternalConfiguration( connection: .unresolved(host: host, port: port), + connectTimeout: connectTimeout, authentication: authentication, tls: tls ) From d648c5b4594ffbc2f6173318f70f5531e05ccb4e Mon Sep 17 00:00:00 2001 From: Zach Rausnitz Date: Fri, 3 Jun 2022 19:48:43 -0400 Subject: [PATCH 108/292] Make backend key data optional (#296) --- .../Connection/PostgresConnection.swift | 12 ++++++- .../ConnectionStateMachine.swift | 21 ++++++------ .../New/PostgresChannelHandler.swift | 2 +- .../AuthenticationStateMachineTests.swift | 16 +++++----- .../ConnectionStateMachineTests.swift | 32 +++++++++++++++---- .../ConnectionAction+TestUtils.swift | 5 +-- .../New/PostgresChannelHandlerTests.swift | 6 ++-- 7 files changed, 62 insertions(+), 32 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 70f61730..1784dd19 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -78,6 +78,12 @@ public final class PostgresConnection { /// - Default: 5432 public var port: Int + /// Require connection to provide `BackendKeyData`. + /// For use with Amazon RDS Proxy, this must be set to false. + /// + /// - Default: true + public var requireBackendKeyData: Bool = true + /// Specifies a timeout to apply to a connection attempt. /// /// - Default: 10 seconds @@ -401,7 +407,8 @@ extension PostgresConnection { connection: .resolved(address: socketAddress, serverName: serverHostname), connectTimeout: .seconds(10), authentication: nil, - tls: tls + tls: tls, + requireBackendKeyData: true ) return PostgresConnection.connect( @@ -764,6 +771,8 @@ extension PostgresConnection { var authentication: Configuration.Authentication? var tls: Configuration.TLS + + var requireBackendKeyData: Bool } } @@ -773,6 +782,7 @@ extension PostgresConnection.InternalConfiguration { self.connection = .unresolved(host: config.connection.host, port: config.connection.port) self.connectTimeout = config.connection.connectTimeout self.tls = config.tls + self.requireBackendKeyData = config.connection.requireBackendKeyData } } diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 31a9ba1d..91e6c007 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -5,9 +5,7 @@ struct ConnectionStateMachine { typealias TransactionState = PostgresBackendMessage.TransactionState struct ConnectionContext { - let processID: Int32 - let secretKey: Int32 - + let backendKeyData: Optional var parameters: [String: String] var transactionState: TransactionState } @@ -113,17 +111,20 @@ struct ConnectionStateMachine { } private var state: State + private let requireBackendKeyData: Bool private var taskQueue = CircularBuffer() private var quiescingState: QuiescingState = .notQuiescing - init() { + init(requireBackendKeyData: Bool) { self.state = .initialized + self.requireBackendKeyData = requireBackendKeyData } #if DEBUG /// for testing purposes only - init(_ state: State) { + init(_ state: State, requireBackendKeyData: Bool = true) { self.state = state + self.requireBackendKeyData = requireBackendKeyData } #endif @@ -543,14 +544,12 @@ struct ConnectionStateMachine { mutating func readyForQueryReceived(_ transactionState: PostgresBackendMessage.TransactionState) -> ConnectionAction { switch self.state { case .authenticated(let backendKeyData, let parameters): - guard let keyData = backendKeyData else { - // `backendKeyData` must have been received, before receiving the first `readyForQuery` + if self.requireBackendKeyData && backendKeyData == nil { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.readyForQuery(transactionState))) } let connectionContext = ConnectionContext( - processID: keyData.processID, - secretKey: keyData.secretKey, + backendKeyData: backendKeyData, parameters: parameters, transactionState: transactionState) @@ -1314,8 +1313,8 @@ extension ConnectionStateMachine.State: CustomDebugStringConvertible { extension ConnectionStateMachine.ConnectionContext: CustomDebugStringConvertible { var debugDescription: String { """ - (processID: \(self.processID), \ - secretKey: \(self.secretKey), \ + (processID: \(self.backendKeyData?.processID != nil ? String(self.backendKeyData!.processID) : "nil")), \ + secretKey: \(self.backendKeyData?.secretKey != nil ? String(self.backendKeyData!.secretKey) : "nil")), \ parameters: \(String(reflecting: self.parameters))) """ } diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 348a9f21..089dbf7e 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -32,7 +32,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { logger: Logger, configureSSLCallback: ((Channel) throws -> Void)?) { - self.state = ConnectionStateMachine() + self.state = ConnectionStateMachine(requireBackendKeyData: configuration.requireBackendKeyData) self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift index 238f4884..18fbc71b 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift @@ -7,7 +7,7 @@ class AuthenticationStateMachineTests: XCTestCase { func testAuthenticatePlaintext() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") - var state = ConnectionStateMachine() + var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) @@ -17,7 +17,7 @@ class AuthenticationStateMachineTests: XCTestCase { func testAuthenticateMD5() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") - var state = ConnectionStateMachine() + var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) @@ -28,7 +28,7 @@ class AuthenticationStateMachineTests: XCTestCase { func testAuthenticateMD5WithoutPassword() { let authContext = AuthContext(username: "test", password: nil, database: "test") - var state = ConnectionStateMachine() + var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) @@ -39,7 +39,7 @@ class AuthenticationStateMachineTests: XCTestCase { func testAuthenticateOkAfterStartUpWithoutAuthChallenge() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") - var state = ConnectionStateMachine() + var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait) @@ -47,7 +47,7 @@ class AuthenticationStateMachineTests: XCTestCase { func testAuthenticationFailure() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") - var state = ConnectionStateMachine() + var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) @@ -79,7 +79,7 @@ class AuthenticationStateMachineTests: XCTestCase { for (message, mechanism) in unsupported { let authContext = AuthContext(username: "test", password: "abc123", database: "test") - var state = ConnectionStateMachine() + var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(message), @@ -98,7 +98,7 @@ class AuthenticationStateMachineTests: XCTestCase { for message in unexpected { let authContext = AuthContext(username: "test", password: "abc123", database: "test") - var state = ConnectionStateMachine() + var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(message), @@ -125,7 +125,7 @@ class AuthenticationStateMachineTests: XCTestCase { for message in unexpected { let authContext = AuthContext(username: "test", password: "abc123", database: "test") - var state = ConnectionStateMachine() + var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index 4a63e31c..aeabc1fa 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -8,7 +8,7 @@ class ConnectionStateMachineTests: XCTestCase { func testStartup() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") - var state = ConnectionStateMachine() + var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(.plaintext), .sendPasswordMessage(.cleartext, authContext)) @@ -17,7 +17,7 @@ class ConnectionStateMachineTests: XCTestCase { func testSSLStartupSuccess() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") - var state = ConnectionStateMachine() + var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest) XCTAssertEqual(state.sslSupportedReceived(), .establishSSLConnection) XCTAssertEqual(state.sslHandlerAdded(), .wait) @@ -30,7 +30,7 @@ class ConnectionStateMachineTests: XCTestCase { func testSSLStartupFailHandler() { struct SSLHandlerAddError: Error, Equatable {} - var state = ConnectionStateMachine() + var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest) XCTAssertEqual(state.sslSupportedReceived(), .establishSSLConnection) let failError = PSQLError.failedToAddSSLHandler(underlying: SSLHandlerAddError()) @@ -38,7 +38,7 @@ class ConnectionStateMachineTests: XCTestCase { } func testTLSRequiredStartupSSLUnsupported() { - var state = ConnectionStateMachine() + var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest) XCTAssertEqual(state.sslUnsupportedReceived(), @@ -46,7 +46,7 @@ class ConnectionStateMachineTests: XCTestCase { } func testTLSPreferredStartupSSLUnsupported() { - var state = ConnectionStateMachine() + var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .prefer), .sendSSLRequest) XCTAssertEqual(state.sslUnsupportedReceived(), .provideAuthenticationContext) @@ -92,7 +92,7 @@ class ConnectionStateMachineTests: XCTestCase { } func testReadyForQueryReceivedWithoutBackendKeyAfterAuthenticated() { - var state = ConnectionStateMachine(.authenticated(nil, [:])) + var state = ConnectionStateMachine(.authenticated(nil, [:]), requireBackendKeyData: true) XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "DateStyle", value: "ISO, MDY")), .wait) XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "application_name", value: "")), .wait) @@ -110,6 +110,24 @@ class ConnectionStateMachineTests: XCTestCase { .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: PSQLError.unexpectedBackendMessage(.readyForQuery(.idle)), closePromise: nil))) } + func testReadyForQueryReceivedWithoutUnneededBackendKeyAfterAuthenticated() { + var state = ConnectionStateMachine(.authenticated(nil, [:]), requireBackendKeyData: false) + + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "DateStyle", value: "ISO, MDY")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "application_name", value: "")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_encoding", value: "UTF8")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "integer_datetimes", value: "on")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "client_encoding", value: "UTF8")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "TimeZone", value: "Etc/UTC")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "is_superuser", value: "on")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_version", value: "13.1 (Debian 13.1-1.pgdg100+1)")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "session_authorization", value: "postgres")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "IntervalStyle", value: "postgres")), .wait) + XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "standard_conforming_strings", value: "on")), .wait) + + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + func testErrorIsIgnoredWhenClosingConnection() { // test ignore unclean shutdown when closing connection var stateIgnoreChannelError = ConnectionStateMachine(.closing) @@ -133,7 +151,7 @@ class ConnectionStateMachineTests: XCTestCase { let queryPromise = eventLoopGroup.next().makePromise(of: PSQLRowStream.self) - var state = ConnectionStateMachine() + var state = ConnectionStateMachine(requireBackendKeyData: true) let extendedQueryContext = ExtendedQueryContext( query: "Select version()", logger: .psqlTest, diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift index fdc69b81..72420798 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -75,6 +75,8 @@ extension ConnectionStateMachine { } static func createConnectionContext(transactionState: PostgresBackendMessage.TransactionState = .idle) -> ConnectionContext { + let backendKeyData = BackendKeyData(processID: 2730, secretKey: 882037977) + let paramaters = [ "DateStyle": "ISO, MDY", "application_name": "", @@ -90,8 +92,7 @@ extension ConnectionStateMachine { ] return ConnectionContext( - processID: 2730, - secretKey: 882037977, + backendKeyData: backendKeyData, parameters: paramaters, transactionState: transactionState ) diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index e2e73b46..298595c7 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -174,7 +174,8 @@ class PostgresChannelHandlerTests: XCTestCase { database: String = "postgres", password: String = "password", tls: PostgresConnection.Configuration.TLS = .disable, - connectTimeout: TimeAmount = .seconds(10) + connectTimeout: TimeAmount = .seconds(10), + requireBackendKeyData: Bool = true ) -> PostgresConnection.InternalConfiguration { let authentication = PostgresConnection.Configuration.Authentication( username: username, @@ -186,7 +187,8 @@ class PostgresChannelHandlerTests: XCTestCase { connection: .unresolved(host: host, port: port), connectTimeout: connectTimeout, authentication: authentication, - tls: tls + tls: tls, + requireBackendKeyData: requireBackendKeyData ) } } From 2fa1cb1c136e4256daee191068b7795625733c2a Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 4 Jun 2022 14:51:39 +0200 Subject: [PATCH 109/292] Enable SwiftPackageIndex docc hosting (#297) --- .spi.yml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .spi.yml diff --git a/.spi.yml b/.spi.yml new file mode 100644 index 00000000..76fd1534 --- /dev/null +++ b/.spi.yml @@ -0,0 +1,6 @@ +version: 1 +builder: + configs: + - documentation_targets: + - PostgresNIO + From 4b8ec141f040602b2fcc8c84728fd8d1b93e9dd2 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Mon, 27 Jun 2022 07:41:48 -0500 Subject: [PATCH 110/292] Use provided logger consistently in `PostgresConnection.send(_:logger:)` (#299) --- Sources/PostgresNIO/Connection/PostgresConnection.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 1784dd19..4b429189 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -601,7 +601,7 @@ extension PostgresConnection: PostgresDatabase { } case .prepareQuery(let request): - resultFuture = self.prepareStatement(request.query, with: request.name, logger: self.logger).map { + resultFuture = self.prepareStatement(request.query, with: request.name, logger: logger).map { request.prepared = PreparedQuery(underlying: $0, database: self) } From 08226c5128feb54be78abe8fe3f9796478c1e6bd Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 27 Sep 2022 17:29:13 +0200 Subject: [PATCH 111/292] Fix new NIO warnings (#300) --- Package.swift | 4 +++- .../Connection/PostgresConnection.swift | 6 +++--- Sources/PostgresNIO/Data/PostgresData.swift | 6 +----- Sources/PostgresNIO/Data/PostgresRow.swift | 4 ---- Sources/PostgresNIO/New/Messages/DataRow.swift | 4 ---- Sources/PostgresNIO/New/PostgresCell.swift | 4 ---- .../New/PostgresRowSequenceTests.swift | 14 +++++++------- 7 files changed, 14 insertions(+), 28 deletions(-) diff --git a/Package.swift b/Package.swift index 44d4edef..9db97e22 100644 --- a/Package.swift +++ b/Package.swift @@ -13,7 +13,8 @@ let package = Package( .library(name: "PostgresNIO", targets: ["PostgresNIO"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-nio.git", from: "2.35.0"), + .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.41.1"), .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.11.4"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.1"), .package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "3.0.0"), @@ -22,6 +23,7 @@ let package = Package( ], targets: [ .target(name: "PostgresNIO", dependencies: [ + .product(name: "Atomics", package: "swift-atomics"), .product(name: "Crypto", package: "swift-crypto"), .product(name: "Logging", package: "swift-log"), .product(name: "Metrics", package: "swift-metrics"), diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 4b429189..56684098 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -1,5 +1,5 @@ +import Atomics import NIOCore -import NIOConcurrencyHelpers #if canImport(Network) import NIOTransportServices #endif @@ -379,7 +379,7 @@ public final class PostgresConnection { // MARK: Connect extension PostgresConnection { - static let idGenerator = NIOAtomic.makeAtomic(value: 0) + static let idGenerator = ManagedAtomic(0) @available(*, deprecated, message: "Use the new connect method that allows you to connect and authenticate in a single step", @@ -412,7 +412,7 @@ extension PostgresConnection { ) return PostgresConnection.connect( - connectionID: idGenerator.add(1), + connectionID: self.idGenerator.wrappingIncrementThenLoad(ordering: .relaxed), configuration: configuration, logger: logger, on: eventLoop diff --git a/Sources/PostgresNIO/Data/PostgresData.swift b/Sources/PostgresNIO/Data/PostgresData.swift index 16d4b3ee..1ae8af2f 100644 --- a/Sources/PostgresNIO/Data/PostgresData.swift +++ b/Sources/PostgresNIO/Data/PostgresData.swift @@ -1,9 +1,5 @@ -#if swift(>=5.6) -@preconcurrency import NIOCore -#else import NIOCore -#endif -import Foundation +import struct Foundation.UUID public struct PostgresData: CustomStringConvertible, CustomDebugStringConvertible { public static var null: PostgresData { diff --git a/Sources/PostgresNIO/Data/PostgresRow.swift b/Sources/PostgresNIO/Data/PostgresRow.swift index 028fe656..c766c383 100644 --- a/Sources/PostgresNIO/Data/PostgresRow.swift +++ b/Sources/PostgresNIO/Data/PostgresRow.swift @@ -1,8 +1,4 @@ -#if swift(>=5.6) -@preconcurrency import NIOCore -#else import NIOCore -#endif import class Foundation.JSONDecoder /// `PostgresRow` represents a single table row that is received from the server for a query or a prepared statement. diff --git a/Sources/PostgresNIO/New/Messages/DataRow.swift b/Sources/PostgresNIO/New/Messages/DataRow.swift index 0deb0043..d0b078c7 100644 --- a/Sources/PostgresNIO/New/Messages/DataRow.swift +++ b/Sources/PostgresNIO/New/Messages/DataRow.swift @@ -1,8 +1,4 @@ -#if swift(>=5.6) -@preconcurrency import NIOCore -#else import NIOCore -#endif /// A backend data row message. /// diff --git a/Sources/PostgresNIO/New/PostgresCell.swift b/Sources/PostgresNIO/New/PostgresCell.swift index 8d11c78b..f13833a9 100644 --- a/Sources/PostgresNIO/New/PostgresCell.swift +++ b/Sources/PostgresNIO/New/PostgresCell.swift @@ -1,8 +1,4 @@ -#if swift(>=5.6) -@preconcurrency import NIOCore -#else import NIOCore -#endif public struct PostgresCell: Equatable { public var bytes: ByteBuffer? diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index 6d7bc24b..d0c1e0cf 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -1,5 +1,5 @@ +import Atomics import NIOEmbedded -import NIOConcurrencyHelpers import Dispatch import XCTest @testable import PostgresNIO @@ -445,22 +445,22 @@ final class PostgresRowSequenceTests: XCTestCase { final class MockRowDataSource: PSQLRowsDataSource { var requestCount: Int { - self._requestCount.load() + self._requestCount.load(ordering: .relaxed) } var cancelCount: Int { - self._cancelCount.load() + self._cancelCount.load(ordering: .relaxed) } - private let _requestCount = NIOAtomic.makeAtomic(value: 0) - private let _cancelCount = NIOAtomic.makeAtomic(value: 0) + private let _requestCount = ManagedAtomic(0) + private let _cancelCount = ManagedAtomic(0) func request(for stream: PSQLRowStream) { - self._requestCount.add(1) + self._requestCount.wrappingIncrement(ordering: .relaxed) } func cancel(for stream: PSQLRowStream) { - self._cancelCount.add(1) + self._cancelCount.wrappingIncrement(ordering: .relaxed) } } #endif From 382b0e10c077bd579d85934b255b33b7dd0220f0 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 27 Sep 2022 19:33:39 +0200 Subject: [PATCH 112/292] Replace Lock with new NIOLock (#305) SwiftNIO `2.42.0` has deprecated `Lock` and replaced it with a new `NIOLock`. This commit removes all uses of `Lock` and replaces them with `NIOLock`. Further, require new package versions. --- Package.swift | 8 ++++---- Sources/PostgresNIO/New/PostgresRowSequence.swift | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Package.swift b/Package.swift index 9db97e22..dc4197a8 100644 --- a/Package.swift +++ b/Package.swift @@ -14,12 +14,12 @@ let package = Package( ], dependencies: [ .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"), - .package(url: "https://github.com/apple/swift-nio.git", from: "2.41.1"), - .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.11.4"), - .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.1"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.42.0"), + .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.13.1"), + .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.22.1"), .package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "3.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.0.0"), - .package(url: "https://github.com/apple/swift-log.git", from: "1.4.0"), + .package(url: "https://github.com/apple/swift-log.git", from: "1.4.4"), ], targets: [ .target(name: "PostgresNIO", dependencies: [ diff --git a/Sources/PostgresNIO/New/PostgresRowSequence.swift b/Sources/PostgresNIO/New/PostgresRowSequence.swift index 2298c541..d86b6e8a 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence.swift @@ -70,7 +70,7 @@ extension PostgresRowSequence { } final class AsyncStreamConsumer { - let lock = Lock() + let lock = NIOLock() let lookupTable: [String: Int] let columns: [RowDescription.Column] From bfd17ae4381a061e366c22fd763a7f80cc2d38f0 Mon Sep 17 00:00:00 2001 From: Mahdi Bahrami Date: Mon, 3 Oct 2022 15:19:56 +0330 Subject: [PATCH 113/292] Increase bind parameters limit (#298) --- .../Connection/PostgresConnection.swift | 6 +- Sources/PostgresNIO/New/Messages/Bind.swift | 4 +- .../New/Messages/ParameterDescription.swift | 5 +- Sources/PostgresNIO/New/Messages/Parse.swift | 2 +- .../PSQLIntegrationTests.swift | 56 +++++++++++++++++++ Tests/IntegrationTests/PostgresNIOTests.swift | 2 +- .../New/Messages/ParseTests.swift | 2 +- 7 files changed, 65 insertions(+), 12 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 56684098..552abe94 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -309,7 +309,7 @@ public final class PostgresConnection { private func queryStream(_ query: PostgresQuery, logger: Logger) -> EventLoopFuture { var logger = logger logger[postgresMetadataKey: .connectionID] = "\(self.id)" - guard query.binds.count <= Int(Int16.max) else { + guard query.binds.count <= Int(UInt16.max) else { return self.channel.eventLoop.makeFailedFuture(PSQLError.tooManyParameters) } @@ -341,7 +341,7 @@ public final class PostgresConnection { } func execute(_ executeStatement: PSQLExecuteStatement, logger: Logger) -> EventLoopFuture { - guard executeStatement.binds.count <= Int(Int16.max) else { + guard executeStatement.binds.count <= Int(UInt16.max) else { return self.channel.eventLoop.makeFailedFuture(PSQLError.tooManyParameters) } let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) @@ -498,7 +498,7 @@ extension PostgresConnection { var logger = logger logger[postgresMetadataKey: .connectionID] = "\(self.id)" - guard query.binds.count <= Int(Int16.max) else { + guard query.binds.count <= Int(UInt16.max) else { throw PSQLError.tooManyParameters } let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) diff --git a/Sources/PostgresNIO/New/Messages/Bind.swift b/Sources/PostgresNIO/New/Messages/Bind.swift index 9fc0445e..898018d4 100644 --- a/Sources/PostgresNIO/New/Messages/Bind.swift +++ b/Sources/PostgresNIO/New/Messages/Bind.swift @@ -20,14 +20,14 @@ extension PostgresFrontendMessage { // zero to indicate that there are no parameters or that the parameters all use the // default format (text); or one, in which case the specified format code is applied // to all parameters; or it can equal the actual number of parameters. - buffer.writeInteger(Int16(self.bind.count)) + buffer.writeInteger(UInt16(self.bind.count)) // The parameter format codes. Each must presently be zero (text) or one (binary). self.bind.metadata.forEach { buffer.writeInteger($0.format.rawValue) } - buffer.writeInteger(Int16(self.bind.count)) + buffer.writeInteger(UInt16(self.bind.count)) var parametersCopy = self.bind.bytes buffer.writeBuffer(¶metersCopy) diff --git a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift index 0d519583..1ccc91e5 100644 --- a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift +++ b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift @@ -7,10 +7,7 @@ extension PostgresBackendMessage { var dataTypes: [PostgresDataType] static func decode(from buffer: inout ByteBuffer) throws -> Self { - let parameterCount = try buffer.throwingReadInteger(as: Int16.self) - guard parameterCount >= 0 else { - throw PSQLPartialDecodingError.integerMustBePositiveOrNull(parameterCount) - } + let parameterCount = try buffer.throwingReadInteger(as: UInt16.self) var result = [PostgresDataType]() result.reserveCapacity(Int(parameterCount)) diff --git a/Sources/PostgresNIO/New/Messages/Parse.swift b/Sources/PostgresNIO/New/Messages/Parse.swift index 268ad4ff..9d3cfa0b 100644 --- a/Sources/PostgresNIO/New/Messages/Parse.swift +++ b/Sources/PostgresNIO/New/Messages/Parse.swift @@ -15,7 +15,7 @@ extension PostgresFrontendMessage { func encode(into buffer: inout ByteBuffer) { buffer.writeNullTerminatedString(self.preparedStatementName) buffer.writeNullTerminatedString(self.query) - buffer.writeInteger(Int16(self.parameters.count)) + buffer.writeInteger(UInt16(self.parameters.count)) self.parameters.forEach { dataType in buffer.writeInteger(dataType.rawValue) diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 38443c5f..5339d3f8 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -329,4 +329,60 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(obj?.bar, 2) } } + +#if swift(>=5.5.2) + func testBindMaximumParameters() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + try await withTestConnection(on: eventLoop) { connection in + // Max binds limit is UInt16.max which is 65535 which is 3 * 5 * 17 * 257 + // Max columns limit is 1664, so we will only make 5 * 257 columns which is less + // Then we will insert 3 * 17 rows + // In the insertion, there will be a total of 3 * 17 * 5 * 257 == UInt16.max bindings + // If the test is successful, it means Postgres supports UInt16.max bindings + let columnsCount = 5 * 257 + let rowsCount = 3 * 17 + + let createQuery = PostgresQuery( + unsafeSQL: """ + CREATE TABLE table1 ( + \((0.. String in + "$\(rowIndex * columnsCount + columnIndex + 1)" + } + return "(\(indices.joined(separator: ", ")))" + }.joined(separator: ", ") + let insertionQuery = PostgresQuery( + unsafeSQL: "INSERT INTO table1 VALUES \(insertionValues)", + binds: binds + ) + try await connection.query(insertionQuery, logger: .psqlTest) + + let countQuery = PostgresQuery(unsafeSQL: "SELECT COUNT(*) FROM table1") + let countRows = try await connection.query(countQuery, logger: .psqlTest) + var countIterator = countRows.makeAsyncIterator() + let insertedRowsCount = try await countIterator.next()?.decode(Int.self, context: .default) + XCTAssertEqual(rowsCount, insertedRowsCount) + + let dropQuery = PostgresQuery(unsafeSQL: "DROP TABLE table1") + try await connection.query(dropQuery, logger: .psqlTest) + } + } +#endif } diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index 4ff68806..b455fdef 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -1061,7 +1061,7 @@ final class PostgresNIOTests: XCTestCase { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } - let binds = [PostgresData].init(repeating: .null, count: Int(Int16.max) + 1) + let binds = [PostgresData].init(repeating: .null, count: Int(UInt16.max) + 1) XCTAssertThrowsError(try conn?.query("SELECT version()", binds).wait()) { error in guard case .tooManyParameters = (error as? PSQLError)?.base else { return XCTFail("Unexpected error: \(error)") diff --git a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift index 3d562473..723ad1e6 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift @@ -26,7 +26,7 @@ class ParseTests: XCTestCase { XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) XCTAssertEqual(byteBuffer.readNullTerminatedString(), parse.preparedStatementName) XCTAssertEqual(byteBuffer.readNullTerminatedString(), parse.query) - XCTAssertEqual(byteBuffer.readInteger(as: Int16.self), Int16(parse.parameters.count)) + XCTAssertEqual(byteBuffer.readInteger(as: UInt16.self), UInt16(parse.parameters.count)) XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.bool.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.int8.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.bytea.rawValue) From 17ba80f29c61e0d17cc02d0c91b647632b181380 Mon Sep 17 00:00:00 2001 From: Mahdi Bahrami Date: Mon, 3 Oct 2022 17:32:39 +0330 Subject: [PATCH 114/292] Update test CI (#306) * Update test.yml * exclude nightlies from codecov --- .github/workflows/test.yml | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 10190870..ff2c8fe7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -12,14 +12,12 @@ jobs: strategy: fail-fast: false matrix: - swift: - - swift:5.4 - - swift:5.5 - - swift:5.6 - - swiftlang/swift:nightly-main - os: - - focal - container: ${{ format('{0}-{1}', matrix.swift, matrix.os) }} + container: + - swift:5.5-bionic + - swift:5.6-focal + - swift:5.7-jammy + - swiftlang/swift:nightly-main-jammy + container: ${{ matrix.container }} runs-on: ubuntu-latest env: LOG_LEVEL: debug @@ -29,6 +27,7 @@ jobs: - name: Run unit tests with code coverage and Thread Sanitizer run: swift test --enable-test-discovery --filter=^PostgresNIOTests --sanitize=thread --enable-code-coverage - name: Submit coverage report to Codecov.io + if: "!contains(matrix.container, 'nightly')" uses: vapor/swift-codecov-action@v0.2 with: cc_flags: 'unittests' @@ -53,7 +52,7 @@ jobs: dbauth: md5 - dbimage: postgres:11 dbauth: trust - container: swift:5.6-focal + container: swift:5.7-jammy runs-on: ubuntu-latest env: LOG_LEVEL: debug @@ -153,8 +152,7 @@ jobs: api-breakage: if: github.event_name == 'pull_request' runs-on: ubuntu-latest - container: - image: swift:5.6-focal + container: swift:5.7-jammy steps: - name: Checkout uses: actions/checkout@v3 From 2cad52aa6a08c32103340b631c6b6b339c3a120f Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 3 Oct 2022 16:14:29 +0200 Subject: [PATCH 115/292] Remove need to supply `PostgresDecodingContext` when decoding Rows (#307) --- .../New/PostgresRow-multi-decode.swift | 90 +++++++++++++ .../PostgresRowSequence-multi-decode.swift | 122 +++++++++++++++--- dev/generate-postgresrow-multi-decode.sh | 42 +++++- ...nerate-postgresrowsequence-multi-decode.sh | 46 ++++++- 4 files changed, 278 insertions(+), 22 deletions(-) diff --git a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift index d5386b08..4fe396ec 100644 --- a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift +++ b/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift @@ -31,6 +31,12 @@ extension PostgresRow { } } + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0).Type, file: String = #file, line: Int = #line) throws -> (T0) { + try self.decode(T0.self, context: .default, file: file, line: line) + } + @inlinable @_alwaysEmitIntoClient public func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1) { @@ -67,6 +73,12 @@ extension PostgresRow { } } + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1).Type, file: String = #file, line: Int = #line) throws -> (T0, T1) { + try self.decode((T0, T1).self, context: .default, file: file, line: line) + } + @inlinable @_alwaysEmitIntoClient public func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2) { @@ -109,6 +121,12 @@ extension PostgresRow { } } + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2) { + try self.decode((T0, T1, T2).self, context: .default, file: file, line: line) + } + @inlinable @_alwaysEmitIntoClient public func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3) { @@ -157,6 +175,12 @@ extension PostgresRow { } } + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3) { + try self.decode((T0, T1, T2, T3).self, context: .default, file: file, line: line) + } + @inlinable @_alwaysEmitIntoClient public func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { @@ -211,6 +235,12 @@ extension PostgresRow { } } + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { + try self.decode((T0, T1, T2, T3, T4).self, context: .default, file: file, line: line) + } + @inlinable @_alwaysEmitIntoClient public func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { @@ -271,6 +301,12 @@ extension PostgresRow { } } + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { + try self.decode((T0, T1, T2, T3, T4, T5).self, context: .default, file: file, line: line) + } + @inlinable @_alwaysEmitIntoClient public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { @@ -337,6 +373,12 @@ extension PostgresRow { } } + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { + try self.decode((T0, T1, T2, T3, T4, T5, T6).self, context: .default, file: file, line: line) + } + @inlinable @_alwaysEmitIntoClient public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { @@ -409,6 +451,12 @@ extension PostgresRow { } } + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { + try self.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: .default, file: file, line: line) + } + @inlinable @_alwaysEmitIntoClient public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { @@ -487,6 +535,12 @@ extension PostgresRow { } } + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { + try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: .default, file: file, line: line) + } + @inlinable @_alwaysEmitIntoClient public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { @@ -571,6 +625,12 @@ extension PostgresRow { } } + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { + try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: .default, file: file, line: line) + } + @inlinable @_alwaysEmitIntoClient public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { @@ -661,6 +721,12 @@ extension PostgresRow { } } + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { + try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: .default, file: file, line: line) + } + @inlinable @_alwaysEmitIntoClient public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { @@ -757,6 +823,12 @@ extension PostgresRow { } } + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { + try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: .default, file: file, line: line) + } + @inlinable @_alwaysEmitIntoClient public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { @@ -859,6 +931,12 @@ extension PostgresRow { } } + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { + try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: .default, file: file, line: line) + } + @inlinable @_alwaysEmitIntoClient public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { @@ -967,6 +1045,12 @@ extension PostgresRow { } } + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { + try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: .default, file: file, line: line) + } + @inlinable @_alwaysEmitIntoClient public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { @@ -1080,4 +1164,10 @@ extension PostgresRow { ) } } + + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { + try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: .default, file: file, line: line) + } } diff --git a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift index 0b3302c1..6dc3d9f1 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift @@ -1,10 +1,10 @@ /// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrowsequence-multi-decode.sh #if swift(>=5.5) && canImport(_Concurrency) -extension PostgresRowSequence { +extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode(T0.self, context: context, file: file, line: line) } @@ -12,7 +12,13 @@ extension PostgresRowSequence { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.decode(T0.self, context: .default, file: file, line: line) + } + + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1).self, context: context, file: file, line: line) } @@ -20,7 +26,13 @@ extension PostgresRowSequence { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.decode((T0, T1).self, context: .default, file: file, line: line) + } + + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2).self, context: context, file: file, line: line) } @@ -28,7 +40,13 @@ extension PostgresRowSequence { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.decode((T0, T1, T2).self, context: .default, file: file, line: line) + } + + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3).self, context: context, file: file, line: line) } @@ -36,7 +54,13 @@ extension PostgresRowSequence { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.decode((T0, T1, T2, T3).self, context: .default, file: file, line: line) + } + + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4).self, context: context, file: file, line: line) } @@ -44,7 +68,13 @@ extension PostgresRowSequence { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.decode((T0, T1, T2, T3, T4).self, context: .default, file: file, line: line) + } + + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5).self, context: context, file: file, line: line) } @@ -52,7 +82,13 @@ extension PostgresRowSequence { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.decode((T0, T1, T2, T3, T4, T5).self, context: .default, file: file, line: line) + } + + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6).self, context: context, file: file, line: line) } @@ -60,7 +96,13 @@ extension PostgresRowSequence { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.decode((T0, T1, T2, T3, T4, T5, T6).self, context: .default, file: file, line: line) + } + + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: context, file: file, line: line) } @@ -68,7 +110,13 @@ extension PostgresRowSequence { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: .default, file: file, line: line) + } + + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: context, file: file, line: line) } @@ -76,7 +124,13 @@ extension PostgresRowSequence { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: .default, file: file, line: line) + } + + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: context, file: file, line: line) } @@ -84,7 +138,13 @@ extension PostgresRowSequence { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: .default, file: file, line: line) + } + + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: context, file: file, line: line) } @@ -92,7 +152,13 @@ extension PostgresRowSequence { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: .default, file: file, line: line) + } + + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: context, file: file, line: line) } @@ -100,7 +166,13 @@ extension PostgresRowSequence { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: .default, file: file, line: line) + } + + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: context, file: file, line: line) } @@ -108,7 +180,13 @@ extension PostgresRowSequence { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: .default, file: file, line: line) + } + + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: context, file: file, line: line) } @@ -116,10 +194,22 @@ extension PostgresRowSequence { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: .default, file: file, line: line) + } + + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: context, file: file, line: line) } } + + @inlinable + @_alwaysEmitIntoClient + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: .default, file: file, line: line) + } } #endif diff --git a/dev/generate-postgresrow-multi-decode.sh b/dev/generate-postgresrow-multi-decode.sh index cebd9449..64a37417 100755 --- a/dev/generate-postgresrow-multi-decode.sh +++ b/dev/generate-postgresrow-multi-decode.sh @@ -4,7 +4,7 @@ set -eu here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -function gen() { +function genWithContextParameter() { how_many=$1 if [[ $how_many -ne 1 ]] ; then @@ -81,6 +81,43 @@ function gen() { echo " }" } +function genWithoutContextParameter() { + how_many=$1 + + echo "" + + echo " @inlinable" + echo " @_alwaysEmitIntoClient" + echo -n " public func decode(_: (T0" + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", T$(($n))" + done + echo -n ").Type, file: String = #file, line: Int = #line) throws" + + echo -n " -> (T0" + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", T$(($n))" + done + echo ") {" + echo -n " try self.decode(" + if [[ $how_many -eq 1 ]] ; then + echo -n "T0.self" + else + echo -n "(T0" + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", T$(($n))" + done + echo -n ").self" + fi + echo ", context: .default, file: file, line: line)" + echo " }" +} + grep -q "ByteBuffer" "${BASH_SOURCE[0]}" || { echo >&2 "ERROR: ${BASH_SOURCE[0]}: file or directory not found (this should be this script)" exit 1 @@ -98,7 +135,8 @@ echo "extension PostgresRow {" # - widening the inverval below (eg. going from {1..15} to {1..25}) is Semver minor # - narrowing the interval below is SemVer _MAJOR_! for n in {1..15}; do - gen "$n" + genWithContextParameter "$n" + genWithoutContextParameter "$n" done echo "}" } > "$here/../Sources/PostgresNIO/New/PostgresRow-multi-decode.swift" diff --git a/dev/generate-postgresrowsequence-multi-decode.sh b/dev/generate-postgresrowsequence-multi-decode.sh index 284b0049..f4a29c95 100755 --- a/dev/generate-postgresrowsequence-multi-decode.sh +++ b/dev/generate-postgresrowsequence-multi-decode.sh @@ -4,7 +4,7 @@ set -eu here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -function gen() { +function genWithContextParameter() { how_many=$1 if [[ $how_many -ne 1 ]] ; then @@ -24,7 +24,7 @@ function gen() { done echo -n ").Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) " - echo -n "-> AsyncThrowingMapSequence AsyncThrowingMapSequence(_: (T0" + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", T$(($n))" + done + echo -n ").Type, file: String = #file, line: Int = #line) " + echo -n "-> AsyncThrowingMapSequence {" + + echo -n " self.decode(" + if [[ $how_many -eq 1 ]] ; then + echo -n "T0.self" + else + echo -n "(T0" + for ((n = 1; n<$how_many; n +=1)); do + echo -n ", T$(($n))" + done + echo -n ").self" + fi + echo ", context: .default, file: file, line: line)" + echo " }" +} + grep -q "ByteBuffer" "${BASH_SOURCE[0]}" || { echo >&2 "ERROR: ${BASH_SOURCE[0]}: file or directory not found (this should be this script)" exit 1 @@ -60,13 +97,14 @@ EOF echo echo "#if swift(>=5.5) && canImport(_Concurrency)" -echo "extension PostgresRowSequence {" +echo "extension AsyncSequence where Element == PostgresRow {" # note: # - widening the inverval below (eg. going from {1..15} to {1..25}) is Semver minor # - narrowing the interval below is SemVer _MAJOR_! for n in {1..15}; do - gen "$n" + genWithContextParameter "$n" + genWithoutContextParameter "$n" done echo "}" echo "#endif" From e277f93132634c26659222fad891df667ef705c7 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 5 Oct 2022 15:04:21 +0200 Subject: [PATCH 116/292] Raise minimum Swift version to 5.5 (#310) --- Package.swift | 2 +- README.md | 4 +- .../Connection/PostgresConnection.swift | 2 +- Sources/PostgresNIO/New/PSQLRowStream.swift | 10 ++-- .../PostgresRowSequence-multi-decode.swift | 2 +- .../PostgresNIO/New/PostgresRowSequence.swift | 2 +- Tests/IntegrationTests/AsyncTests.swift | 56 ++++++++++++++++++- .../PSQLIntegrationTests.swift | 55 ------------------ .../New/PostgresRowSequenceTests.swift | 2 +- ...nerate-postgresrowsequence-multi-decode.sh | 2 +- 10 files changed, 68 insertions(+), 69 deletions(-) diff --git a/Package.swift b/Package.swift index dc4197a8..03ba1887 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.4 +// swift-tools-version:5.5 import PackageDescription let package = Package( diff --git a/README.md b/README.md index 8fa322ec..9a6b3f47 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [![Team Chat](https://img.shields.io/discord/431917998102675485.svg)][Team Chat] [![MIT License](http://img.shields.io/badge/license-MIT-brightgreen.svg)][MIT License] [![Continuous Integration](https://github.com/vapor/postgres-nio/actions/workflows/test.yml/badge.svg)][Continuous Integration] -[![Swift 5.4](http://img.shields.io/badge/swift-5.4-brightgreen.svg)][Swift 5.4] +[![Swift 5.5](http://img.shields.io/badge/swift-5.5-brightgreen.svg)][Swift 5.5]

@@ -191,7 +191,7 @@ Please see [SECURITY.md] for details on the security process. [Team Chat]: https://discord.gg/vapor [MIT License]: LICENSE [Continuous Integration]: https://github.com/vapor/postgres-nio/actions -[Swift 5.4]: https://swift.org +[Swift 5.5]: https://swift.org [Security.md]: https://github.com/vapor/.github/blob/main/SECURITY.md [`PostgresConnection`]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/PostgresConnection/ diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 552abe94..ac533c6e 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -449,7 +449,7 @@ extension PostgresConnection { // MARK: Async/Await Interface -#if swift(>=5.5) && canImport(_Concurrency) +#if canImport(_Concurrency) extension PostgresConnection { /// Creates a new connection to a Postgres server. diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index 58730851..c73cda20 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -22,7 +22,7 @@ final class PSQLRowStream { case waitingForAll([PostgresRow], EventLoopPromise<[PostgresRow]>, PSQLRowsDataSource) case consumed(Result) - #if swift(>=5.5) && canImport(_Concurrency) + #if canImport(_Concurrency) case asyncSequence(AsyncStreamConsumer, PSQLRowsDataSource) #endif } @@ -63,7 +63,7 @@ final class PSQLRowStream { // MARK: Async Sequence - #if swift(>=5.5) && canImport(_Concurrency) + #if canImport(_Concurrency) func asyncSequence() -> PostgresRowSequence { self.eventLoop.preconditionInEventLoop() @@ -304,7 +304,7 @@ final class PSQLRowStream { // immediately request more dataSource.request(for: self) - #if swift(>=5.5) && canImport(_Concurrency) + #if canImport(_Concurrency) case .asyncSequence(let consumer, _): consumer.receive(newRows) #endif @@ -344,7 +344,7 @@ final class PSQLRowStream { self.downstreamState = .consumed(.success(commandTag)) promise.succeed(rows) - #if swift(>=5.5) && canImport(_Concurrency) + #if canImport(_Concurrency) case .asyncSequence(let consumer, _): consumer.receive(completion: .success(commandTag)) self.downstreamState = .consumed(.success(commandTag)) @@ -371,7 +371,7 @@ final class PSQLRowStream { self.downstreamState = .consumed(.failure(error)) promise.fail(error) - #if swift(>=5.5) && canImport(_Concurrency) + #if canImport(_Concurrency) case .asyncSequence(let consumer, _): consumer.receive(completion: .failure(error)) self.downstreamState = .consumed(.failure(error)) diff --git a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift index 6dc3d9f1..d7429ff8 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift @@ -1,6 +1,6 @@ /// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrowsequence-multi-decode.sh -#if swift(>=5.5) && canImport(_Concurrency) +#if canImport(_Concurrency) extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient diff --git a/Sources/PostgresNIO/New/PostgresRowSequence.swift b/Sources/PostgresNIO/New/PostgresRowSequence.swift index d86b6e8a..2e366432 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence.swift @@ -1,7 +1,7 @@ import NIOCore import NIOConcurrencyHelpers -#if swift(>=5.5) && canImport(_Concurrency) +#if canImport(_Concurrency) /// An async sequence of ``PostgresRow``s. /// /// - Note: This is a struct to allow us to move to a move only type easily once they become available. diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index cb6950d6..9d43397f 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -5,7 +5,7 @@ import PostgresNIO import NIOTransportServices #endif -#if swift(>=5.5.2) +#if canImport(_Concurrency) final class AsyncPostgresConnectionTests: XCTestCase { func test1kRoundTrips() async throws { @@ -64,6 +64,60 @@ final class AsyncPostgresConnectionTests: XCTestCase { } } + func testBindMaximumParameters() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + try await withTestConnection(on: eventLoop) { connection in + // Max binds limit is UInt16.max which is 65535 which is 3 * 5 * 17 * 257 + // Max columns limit is 1664, so we will only make 5 * 257 columns which is less + // Then we will insert 3 * 17 rows + // In the insertion, there will be a total of 3 * 17 * 5 * 257 == UInt16.max bindings + // If the test is successful, it means Postgres supports UInt16.max bindings + let columnsCount = 5 * 257 + let rowsCount = 3 * 17 + + let createQuery = PostgresQuery( + unsafeSQL: """ + CREATE TABLE table1 ( + \((0.. String in + "$\(rowIndex * columnsCount + columnIndex + 1)" + } + return "(\(indices.joined(separator: ", ")))" + }.joined(separator: ", ") + let insertionQuery = PostgresQuery( + unsafeSQL: "INSERT INTO table1 VALUES \(insertionValues)", + binds: binds + ) + try await connection.query(insertionQuery, logger: .psqlTest) + + let countQuery = PostgresQuery(unsafeSQL: "SELECT COUNT(*) FROM table1") + let countRows = try await connection.query(countQuery, logger: .psqlTest) + var countIterator = countRows.makeAsyncIterator() + let insertedRowsCount = try await countIterator.next()?.decode(Int.self, context: .default) + XCTAssertEqual(rowsCount, insertedRowsCount) + + let dropQuery = PostgresQuery(unsafeSQL: "DROP TABLE table1") + try await connection.query(dropQuery, logger: .psqlTest) + } + } + #if canImport(Network) func testSelect10kRowsNetworkFramework() async throws { let eventLoopGroup = NIOTSEventLoopGroup() diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 5339d3f8..5debde90 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -330,59 +330,4 @@ final class IntegrationTests: XCTestCase { } } -#if swift(>=5.5.2) - func testBindMaximumParameters() async throws { - let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } - let eventLoop = eventLoopGroup.next() - - try await withTestConnection(on: eventLoop) { connection in - // Max binds limit is UInt16.max which is 65535 which is 3 * 5 * 17 * 257 - // Max columns limit is 1664, so we will only make 5 * 257 columns which is less - // Then we will insert 3 * 17 rows - // In the insertion, there will be a total of 3 * 17 * 5 * 257 == UInt16.max bindings - // If the test is successful, it means Postgres supports UInt16.max bindings - let columnsCount = 5 * 257 - let rowsCount = 3 * 17 - - let createQuery = PostgresQuery( - unsafeSQL: """ - CREATE TABLE table1 ( - \((0.. String in - "$\(rowIndex * columnsCount + columnIndex + 1)" - } - return "(\(indices.joined(separator: ", ")))" - }.joined(separator: ", ") - let insertionQuery = PostgresQuery( - unsafeSQL: "INSERT INTO table1 VALUES \(insertionValues)", - binds: binds - ) - try await connection.query(insertionQuery, logger: .psqlTest) - - let countQuery = PostgresQuery(unsafeSQL: "SELECT COUNT(*) FROM table1") - let countRows = try await connection.query(countQuery, logger: .psqlTest) - var countIterator = countRows.makeAsyncIterator() - let insertedRowsCount = try await countIterator.next()?.decode(Int.self, context: .default) - XCTAssertEqual(rowsCount, insertedRowsCount) - - let dropQuery = PostgresQuery(unsafeSQL: "DROP TABLE table1") - try await connection.query(dropQuery, logger: .psqlTest) - } - } -#endif } diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index d0c1e0cf..a8e20d76 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -4,7 +4,7 @@ import Dispatch import XCTest @testable import PostgresNIO -#if swift(>=5.5.2) +#if canImport(_Concurrency) final class PostgresRowSequenceTests: XCTestCase { func testBackpressureWorks() async throws { diff --git a/dev/generate-postgresrowsequence-multi-decode.sh b/dev/generate-postgresrowsequence-multi-decode.sh index f4a29c95..126f2a61 100755 --- a/dev/generate-postgresrowsequence-multi-decode.sh +++ b/dev/generate-postgresrowsequence-multi-decode.sh @@ -96,7 +96,7 @@ cat <<"EOF" EOF echo -echo "#if swift(>=5.5) && canImport(_Concurrency)" +echo "#if canImport(_Concurrency)" echo "extension AsyncSequence where Element == PostgresRow {" # note: From a12d09fdf93b8c0a1a48f75b05639a7a18d07565 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 7 Oct 2022 14:00:11 +0200 Subject: [PATCH 117/292] Add docc catalog (#311) --- Sources/PostgresNIO/Docs.docc/index.md | 75 ++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 Sources/PostgresNIO/Docs.docc/index.md diff --git a/Sources/PostgresNIO/Docs.docc/index.md b/Sources/PostgresNIO/Docs.docc/index.md new file mode 100644 index 00000000..a16e62d9 --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/index.md @@ -0,0 +1,75 @@ +# ``PostgresNIO`` + +🐘 Non-blocking, event-driven Swift client for PostgreSQL built on [SwiftNIO]. + +## Overview + +Features: + +- A ``PostgresConnection`` which allows you to connect to, authorize with, query, and retrieve results from a PostgreSQL server +- An async/await interface that supports backpressure +- Automatic conversions between Swift primitive types and the Postgres wire format +- Integrated with the Swift server ecosystem, including use of [SwiftLog]. +- Designed to run efficiently on all supported platforms (tested extensively on Linux and Darwin systems) +- Support for `Network.framework` when available (e.g. on Apple platforms) + +## Topics + +### Connections + +- ``PostgresConnection`` + +### Querying + +- ``PostgresQuery`` +- ``PostgresBindings`` +- ``PostgresRow`` +- ``PostgresRowSequence`` +- ``PostgresRandomAccessRow`` +- ``PostgresCell`` +- ``PreparedQuery`` +- ``PostgresQueryMetadata`` + +### Encoding and Decoding + +- ``PostgresEncodable`` +- ``PostgresEncodingContext`` +- ``PostgresDecodable`` +- ``PostgresDecodingContext`` +- ``PostgresArrayEncodable`` +- ``PostgresArrayDecodable`` +- ``PostgresJSONEncoder`` +- ``PostgresJSONDecoder`` +- ``PostgresDataType`` +- ``PostgresNumeric`` + +### Notifications + +- ``PostgresListenContext`` + +### Errors + +- ``PostgresError`` +- ``PostgresDecodingError`` + +### Deprecated + +These types are already deprecated or will be deprecated in the near future. All of them will be +removed from the public API with the next major release. + +- ``PostgresDatabase`` +- ``PostgresData`` +- ``PostgresDataConvertible`` +- ``PostgresQueryResult`` +- ``PostgresJSONCodable`` +- ``PostgresJSONBCodable`` +- ``PostgresMessageEncoder`` +- ``PostgresMessageDecoder`` +- ``PostgresRequest`` +- ``PostgresMessage`` +- ``PostgresMessageType`` +- ``SASLAuthenticationManager`` +- ``SASLAuthenticationMechanism`` + +[SwiftNIO]: https://github.com/apple/swift-nio +[SwiftLog]: https://github.com/apple/swift-log From c2cdd473ead90c588aa0753b4a09ba2887949e99 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 7 Oct 2022 17:48:13 +0200 Subject: [PATCH 118/292] Add docs outlining changes we made to `PostgresRow/column(name:)` (#312) Co-authored-by: Tim Condon <0xTim@users.noreply.github.com> --- Sources/PostgresNIO/Data/PostgresRow.swift | 2 +- Sources/PostgresNIO/Docs.docc/index.md | 4 + Sources/PostgresNIO/Docs.docc/migrations.md | 102 ++++++++++++++++++++ Sources/PostgresNIO/New/PostgresCell.swift | 40 +++++++- 4 files changed, 145 insertions(+), 3 deletions(-) create mode 100644 Sources/PostgresNIO/Docs.docc/migrations.md diff --git a/Sources/PostgresNIO/Data/PostgresRow.swift b/Sources/PostgresNIO/Data/PostgresRow.swift index c766c383..914667e5 100644 --- a/Sources/PostgresNIO/Data/PostgresRow.swift +++ b/Sources/PostgresNIO/Data/PostgresRow.swift @@ -131,7 +131,7 @@ public struct PostgresRandomAccessRow { let cells: [ByteBuffer?] let lookupTable: [String: Int] - init(_ row: PostgresRow) { + public init(_ row: PostgresRow) { self.cells = [ByteBuffer?](row.data) self.columns = row.columns self.lookupTable = row.lookupTable diff --git a/Sources/PostgresNIO/Docs.docc/index.md b/Sources/PostgresNIO/Docs.docc/index.md index a16e62d9..6b7fd5b0 100644 --- a/Sources/PostgresNIO/Docs.docc/index.md +++ b/Sources/PostgresNIO/Docs.docc/index.md @@ -15,6 +15,10 @@ Features: ## Topics +### Articles + +- + ### Connections - ``PostgresConnection`` diff --git a/Sources/PostgresNIO/Docs.docc/migrations.md b/Sources/PostgresNIO/Docs.docc/migrations.md new file mode 100644 index 00000000..33c8afd4 --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/migrations.md @@ -0,0 +1,102 @@ +# Adopting the new PostgresRow cell API + +This article describes how to adopt the new ``PostgresRow`` cell APIs in existing Postgres code +which use the ``PostgresRow/column(_:)`` API today. + +## TLDR + +1. Map your sequence of ``PostgresRow``s to ``PostgresRandomAccessRow``s. +2. Use the ``PostgresRandomAccessRow/subscript(name:)`` API to receive a ``PostgresCell`` +3. Decode the ``PostgresCell`` into a Swift type using the ``PostgresCell/decode(_:file:line:)`` method. + +```swift +let rows: [PostgresRow] // your existing return value +for row in rows.map({ PostgresRandomAccessRow($0) }) { + let id = try row["id"].decode(UUID.self) + let name = try row["name"].decode(String.self) + let email = try row["email"].decode(String.self) + let age = try row["age"].decode(Int.self) +} +``` + +## Overview + +When Postgres [`1.9.0`] was released we changed the default behaviour of ``PostgresRow``s. +Previously for each row we created an internal lookup table, that allowed you to access the rows' +cells by name: + +```swift +connection.query("SELECT id, name, email, age FROM users").whenComplete { + switch $0 { + case .success(let result): + for row in result.rows { + let id = row.column("id").uuid + let name = row.column("name").string + let email = row.column("email").string + let age = row.column("age").int + // do further processing + } + case .failure(let error): + // handle the error + } +} +``` + +During the last year we introduced a new API that let's you consume ``PostgresRow`` by iterating +its cells. This approach has the performance benefit of not needing to create an internal cell +lookup table for each row: + +```swift +connection.query("SELECT id, name, email, age FROM users").whenComplete { + switch $0 { + case .success(let result): + for row in result.rows { + let (id, name, email, age) = try row.decode((UUID, String, String, Int).self) + // do further processing + } + case .failure(let error): + // handle the error + } +} +``` + +However, since we still supported the ``PostgresRow/column(_:)`` API, which requires a precomputed +lookup table within the row, users were not seeing any performance benefits. To allow users to +benefit of the new fastpath, we changed ``PostgresRow``'s behavior: + +By default the ``PostgresRow`` does not create an internal lookup table for its cells on creation +anymore. Because of this, when using the ``PostgresRow/column(_:)`` API, a throwaway lookup table +needs to be produced on every call. Since this is wasteful we have deprecated this API. Instead we +allow users now to explicitly opt-in into the cell lookup API by using the new +``PostgresRandomAccessRow``. + +```swift +connection.query("SELECT id, name, email, age FROM users").whenComplete { + switch $0 { + case .success(let result): + for row in result.rows.map { PostgresRandomAccessRow($0) } { + let id = try row["id"].decode(UUID.self) + let name = try row["name"].decode(String.self) + let email = try row["email"].decode(String.self) + let age = try row["age"].decode(Int.self) + // do further processing + } + case .failure(let error): + // handle the error + } +} +``` + +## Topics + +### Relevant types + +- ``PostgresConnection`` +- ``PostgresQuery`` +- ``PostgresBindings`` +- ``PostgresRow`` +- ``PostgresRandomAccessRow`` +- ``PostgresEncodable`` +- ``PostgresDecodable`` + +[`1.9.0`]: https://github.com/vapor/postgres-nio/releases/tag/1.9.0 diff --git a/Sources/PostgresNIO/New/PostgresCell.swift b/Sources/PostgresNIO/New/PostgresCell.swift index f13833a9..39710e8e 100644 --- a/Sources/PostgresNIO/New/PostgresCell.swift +++ b/Sources/PostgresNIO/New/PostgresCell.swift @@ -1,14 +1,26 @@ import NIOCore +/// A representation of a cell value within a ``PostgresRow`` and ``PostgresRandomAccessRow``. public struct PostgresCell: Equatable { + /// The cell's value as raw bytes. public var bytes: ByteBuffer? + /// The cell's data type. This is important metadata when decoding the cell. public var dataType: PostgresDataType + /// The format in which the cell's bytes are encoded. public var format: PostgresFormat + /// The cell's column name within the row. public var columnName: String + /// The cell's column index within the row. public var columnIndex: Int - init(bytes: ByteBuffer?, dataType: PostgresDataType, format: PostgresFormat, columnName: String, columnIndex: Int) { + public init( + bytes: ByteBuffer?, + dataType: PostgresDataType, + format: PostgresFormat, + columnName: String, + columnIndex: Int + ) { self.bytes = bytes self.dataType = dataType self.format = format @@ -19,7 +31,14 @@ public struct PostgresCell: Equatable { } extension PostgresCell { - + /// Decode the cell into a Swift type, that conforms to ``PostgresDecodable`` + /// + /// - Parameters: + /// - _: The Swift type, which conforms to ``PostgresDecodable``, to decode from the cell's ``PostgresCell/bytes`` values. + /// - context: A ``PostgresDecodingContext`` to supply a custom ``PostgresJSONDecoder`` for decoding JSON fields. + /// - file: The source file in which this method was called. Used in the error case in ``PostgresDecodingError``. + /// - line: The source file line in which this method was called. Used in the error case in ``PostgresDecodingError``. + /// - Returns: A decoded Swift type. @inlinable public func decode( _: T.Type, @@ -49,6 +68,23 @@ extension PostgresCell { ) } } + + + /// Decode the cell into a Swift type, that conforms to ``PostgresDecodable`` + /// + /// - Parameters: + /// - _: The Swift type, which conforms to ``PostgresDecodable``, to decode from the cell's ``PostgresCell/bytes`` values. + /// - file: The source file in which this method was called. Used in the error case in ``PostgresDecodingError``. + /// - line: The source file line in which this method was called. Used in the error case in ``PostgresDecodingError``. + /// - Returns: A decoded Swift type. + @inlinable + public func decode( + _: T.Type, + file: String = #file, + line: Int = #line + ) throws -> T { + try self.decode(T.self, context: .default, file: file, line: line) + } } #if swift(>=5.6) From f636f59ed023ab5629fa5d5a44c8a1148fde1bad Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 19 Oct 2022 08:55:38 +0200 Subject: [PATCH 119/292] Link to swiftpackageindex for documentation (#308) --- README.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 9a6b3f47..b82200b4 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ PostgresNIO does not provide a `ConnectionPool` as of today, but this is a [feat ## API Docs -Check out the [PostgresNIO API docs](https://api.vapor.codes/postgres-nio/main/PostgresNIO/) for a +Check out the [PostgresNIO API docs][Documentation] for a detailed look at all of the classes, structs, protocols, and more. ## Getting started @@ -187,20 +187,20 @@ Some queries do not receive any rows from the server (most often `INSERT`, `UPDA Please see [SECURITY.md] for details on the security process. [SSWG Incubation]: https://github.com/swift-server/sswg/blob/main/process/incubation.md#graduated-level -[Documentation]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/ +[Documentation]: https://swiftpackageindex.com/vapor/postgres-nio/documentation [Team Chat]: https://discord.gg/vapor [MIT License]: LICENSE [Continuous Integration]: https://github.com/vapor/postgres-nio/actions [Swift 5.5]: https://swift.org [Security.md]: https://github.com/vapor/.github/blob/main/SECURITY.md -[`PostgresConnection`]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/PostgresConnection/ -[`query(_:logger:)`]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/PostgresConnection/#postgresconnection.query(_:logger:file:line:) -[`PostgresQuery`]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/PostgresQuery/ -[`PostgresRow`]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/PostgresRow/ -[`PostgresRowSequence`]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/PostgresRowSequence/ -[`PostgresDecodable`]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/PostgresDecodable/ -[`PostgresEncodable`]: https://api.vapor.codes/postgres-nio/main/PostgresNIO/PostgresEncodable/ +[`PostgresConnection`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresconnection/ +[`query(_:logger:)`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresconnection/query(_:logger:file:line:)-9mkfn +[`PostgresQuery`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresquery/ +[`PostgresRow`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresrow/ +[`PostgresRowSequence`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresrowsequence/ +[`PostgresDecodable`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresdecodable/ +[`PostgresEncodable`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresencodable/ [PostgresKit]: https://github.com/vapor/postgres-kit From ab1fc3ced1b57c31040124b0f73f5e3522780447 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 26 Oct 2022 14:22:25 +0200 Subject: [PATCH 120/292] AsyncStreamConsumer should stream DataRows. (#316) --- .../PostgresNIO/New/PostgresRowSequence.swift | 41 ++++++++++--------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/Sources/PostgresNIO/New/PostgresRowSequence.swift b/Sources/PostgresNIO/New/PostgresRowSequence.swift index 2e366432..5c18e43a 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence.swift @@ -43,12 +43,24 @@ extension PostgresRowSequence { let _internal: _Internal + let lookupTable: [String: Int] + let columns: [RowDescription.Column] + init(consumer: AsyncStreamConsumer) { self._internal = _Internal(consumer: consumer) + self.lookupTable = consumer.lookupTable + self.columns = consumer.columns } public mutating func next() async throws -> PostgresRow? { - try await self._internal.next() + if let dataRow = try await self._internal.next() { + return PostgresRow( + data: dataRow, + lookupTable: self.lookupTable, + columns: columns + ) + } + return nil } final class _Internal { @@ -62,7 +74,7 @@ extension PostgresRowSequence { self.consumer.iteratorDeinitialized() } - func next() async throws -> PostgresRow? { + func next() async throws -> DataRow? { try await self.consumer.next() } } @@ -111,12 +123,7 @@ final class AsyncStreamConsumer { switch receiveAction { case .succeed(let continuation, let data, signalDemandTo: let source): - let row = PostgresRow( - data: data, - lookupTable: self.lookupTable, - columns: self.columns - ) - continuation.resume(returning: row) + continuation.resume(returning: data) source?.demand() case .none: @@ -175,7 +182,7 @@ final class AsyncStreamConsumer { } } - func next() async throws -> PostgresRow? { + func next() async throws -> DataRow? { self.lock.lock() switch self.state.next() { case .returnNil: @@ -185,11 +192,7 @@ final class AsyncStreamConsumer { case .returnRow(let data, signalDemandTo: let source): self.lock.unlock() source?.demand() - return PostgresRow( - data: data, - lookupTable: self.lookupTable, - columns: self.columns - ) + return data case .throwError(let error): self.lock.unlock() @@ -216,7 +219,7 @@ extension AsyncStreamConsumer { private enum UpstreamState { enum DemandState { case canAskForMore - case waitingForMore(CheckedContinuation?) + case waitingForMore(CheckedContinuation?) } case initialized @@ -395,7 +398,7 @@ extension AsyncStreamConsumer { case none } - mutating func next(for continuation: CheckedContinuation) -> NextSlowPathAction { + mutating func next(for continuation: CheckedContinuation) -> NextSlowPathAction { switch self.upstreamState { case .initialized: preconditionFailure() @@ -422,7 +425,7 @@ extension AsyncStreamConsumer { } enum ReceiveAction { - case succeed(CheckedContinuation, DataRow, signalDemandTo: PSQLRowStream?) + case succeed(CheckedContinuation, DataRow, signalDemandTo: PSQLRowStream?) case none } @@ -462,8 +465,8 @@ extension AsyncStreamConsumer { } enum CompletionResult { - case succeed(CheckedContinuation) - case fail(CheckedContinuation, Error) + case succeed(CheckedContinuation) + case fail(CheckedContinuation, Error) case none } From a365a9b0fe28955d89fec63ccf1eb23875c83a7d Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Wed, 26 Oct 2022 07:53:25 -0500 Subject: [PATCH 121/292] CI update for PostgreSQL 15 (#318) --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ff2c8fe7..c1d82648 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -42,11 +42,11 @@ jobs: fail-fast: false matrix: dbimage: - - postgres:14 + - postgres:15 - postgres:13 - postgres:11 include: - - dbimage: postgres:14 + - dbimage: postgres:15 dbauth: scram-sha-256 - dbimage: postgres:13 dbauth: md5 From 7daf026e145de2c07d6e37f4171b1acb4b5f22b1 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 26 Oct 2022 15:11:06 +0200 Subject: [PATCH 122/292] Use NIOThrowingAsyncSequenceProducer (#317) --- .../PostgresNIO/New/Messages/DataRow.swift | 2 +- Sources/PostgresNIO/New/PSQLRowStream.swift | 74 ++- .../PostgresNIO/New/PostgresRowSequence.swift | 554 +----------------- .../New/PostgresRowSequenceTests.swift | 10 +- 4 files changed, 93 insertions(+), 547 deletions(-) diff --git a/Sources/PostgresNIO/New/Messages/DataRow.swift b/Sources/PostgresNIO/New/Messages/DataRow.swift index d0b078c7..4cdc92f8 100644 --- a/Sources/PostgresNIO/New/Messages/DataRow.swift +++ b/Sources/PostgresNIO/New/Messages/DataRow.swift @@ -117,6 +117,6 @@ extension DataRow { } } -#if swift(>=5.6) +#if swift(>=5.5) extension DataRow: Sendable {} #endif diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index c73cda20..c5a9cd3f 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -2,6 +2,8 @@ import NIOCore import Logging final class PSQLRowStream { + private typealias AsyncSequenceSource = NIOThrowingAsyncSequenceProducer.Source + enum RowSource { case stream(PSQLRowsDataSource) case noRows(Result) @@ -23,7 +25,7 @@ final class PSQLRowStream { case consumed(Result) #if canImport(_Concurrency) - case asyncSequence(AsyncStreamConsumer, PSQLRowsDataSource) + case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource) #endif } @@ -71,26 +73,35 @@ final class PSQLRowStream { preconditionFailure("Invalid state: \(self.downstreamState)") } - let consumer = AsyncStreamConsumer( - lookupTable: self.lookupTable, - columns: self.rowDescription + let producer = NIOThrowingAsyncSequenceProducer.makeSequence( + elementType: DataRow.self, + failureType: Error.self, + backPressureStrategy: AdaptiveRowBuffer(), + delegate: self ) + + let source = producer.source switch bufferState { case .streaming(let bufferedRows, let dataSource): - consumer.startStreaming(bufferedRows, upstream: self) - self.downstreamState = .asyncSequence(consumer, dataSource) + let yieldResult = source.yield(contentsOf: bufferedRows) + self.downstreamState = .asyncSequence(source, dataSource) + + self.eventLoop.execute { + self.executeActionBasedOnYieldResult(yieldResult, source: dataSource) + } case .finished(let buffer, let commandTag): - consumer.startCompleted(buffer, commandTag: commandTag) + _ = source.yield(contentsOf: buffer) + source.finish() self.downstreamState = .consumed(.success(commandTag)) case .failure(let error): - consumer.startFailed(error) + source.finish(error) self.downstreamState = .consumed(.failure(error)) } - return PostgresRowSequence(consumer) + return PostgresRowSequence(producer.sequence, lookupTable: self.lookupTable, columns: self.rowDescription) } func demand() { @@ -128,10 +139,8 @@ final class PSQLRowStream { private func cancel0() { switch self.downstreamState { - case .asyncSequence(let consumer, let dataSource): - let error = PSQLError.connectionClosed - self.downstreamState = .consumed(.failure(error)) - consumer.receive(completion: .failure(error)) + case .asyncSequence(_, let dataSource): + self.downstreamState = .consumed(.failure(CancellationError())) dataSource.cancel(for: self) case .consumed: @@ -305,8 +314,9 @@ final class PSQLRowStream { dataSource.request(for: self) #if canImport(_Concurrency) - case .asyncSequence(let consumer, _): - consumer.receive(newRows) + case .asyncSequence(let consumer, let source): + let yieldResult = consumer.yield(contentsOf: newRows) + self.executeActionBasedOnYieldResult(yieldResult, source: source) #endif case .consumed(.success): @@ -345,8 +355,8 @@ final class PSQLRowStream { promise.succeed(rows) #if canImport(_Concurrency) - case .asyncSequence(let consumer, _): - consumer.receive(completion: .success(commandTag)) + case .asyncSequence(let source, _): + source.finish() self.downstreamState = .consumed(.success(commandTag)) #endif @@ -373,7 +383,7 @@ final class PSQLRowStream { #if canImport(_Concurrency) case .asyncSequence(let consumer, _): - consumer.receive(completion: .failure(error)) + consumer.finish(error) self.downstreamState = .consumed(.failure(error)) #endif @@ -381,6 +391,22 @@ final class PSQLRowStream { break } } + + private func executeActionBasedOnYieldResult(_ yieldResult: AsyncSequenceSource.YieldResult, source: PSQLRowsDataSource) { + self.eventLoop.preconditionInEventLoop() + switch yieldResult { + case .dropped: + // ignore + break + + case .produceMore: + source.request(for: self) + + case .stopProducing: + // ignore + break + } + } var commandTag: String { guard case .consumed(.success(let commandTag)) = self.downstreamState else { @@ -390,6 +416,16 @@ final class PSQLRowStream { } } +extension PSQLRowStream: NIOAsyncSequenceProducerDelegate { + func produceMore() { + self.demand() + } + + func didTerminate() { + self.cancel() + } +} + protocol PSQLRowsDataSource { func request(for stream: PSQLRowStream) @@ -397,7 +433,7 @@ protocol PSQLRowsDataSource { } -#if swift(>=5.6) +#if swift(>=5.5) // Thread safety is guaranteed in the RowStream through dispatching onto the NIO EventLoop. extension PSQLRowStream: @unchecked Sendable {} #endif diff --git a/Sources/PostgresNIO/New/PostgresRowSequence.swift b/Sources/PostgresNIO/New/PostgresRowSequence.swift index 5c18e43a..8248e14a 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence.swift @@ -8,32 +8,24 @@ import NIOConcurrencyHelpers public struct PostgresRowSequence: AsyncSequence { public typealias Element = PostgresRow - final class _Internal { + typealias BackingSequence = NIOThrowingAsyncSequenceProducer - let consumer: AsyncStreamConsumer - - init(consumer: AsyncStreamConsumer) { - self.consumer = consumer - } - - deinit { - // if no iterator was created, we need to cancel the stream - self.consumer.sequenceDeinitialized() - } - - func makeAsyncIterator() -> AsyncIterator { - self.consumer.makeAsyncIterator() - } - } - - let _internal: _Internal + let backing: BackingSequence + let lookupTable: [String: Int] + let columns: [RowDescription.Column] - init(_ consumer: AsyncStreamConsumer) { - self._internal = .init(consumer: consumer) + init(_ backing: BackingSequence, lookupTable: [String: Int], columns: [RowDescription.Column]) { + self.backing = backing + self.lookupTable = lookupTable + self.columns = columns } public func makeAsyncIterator() -> AsyncIterator { - self._internal.makeAsyncIterator() + AsyncIterator( + backing: self.backing.makeAsyncIterator(), + lookupTable: self.lookupTable, + columns: self.columns + ) } } @@ -41,495 +33,27 @@ extension PostgresRowSequence { public struct AsyncIterator: AsyncIteratorProtocol { public typealias Element = PostgresRow - let _internal: _Internal + let backing: BackingSequence.AsyncIterator let lookupTable: [String: Int] let columns: [RowDescription.Column] - init(consumer: AsyncStreamConsumer) { - self._internal = _Internal(consumer: consumer) - self.lookupTable = consumer.lookupTable - self.columns = consumer.columns + init(backing: BackingSequence.AsyncIterator, lookupTable: [String: Int], columns: [RowDescription.Column]) { + self.backing = backing + self.lookupTable = lookupTable + self.columns = columns } public mutating func next() async throws -> PostgresRow? { - if let dataRow = try await self._internal.next() { + if let dataRow = try await self.backing.next() { return PostgresRow( data: dataRow, lookupTable: self.lookupTable, - columns: columns + columns: self.columns ) } return nil } - - final class _Internal { - let consumer: AsyncStreamConsumer - - init(consumer: AsyncStreamConsumer) { - self.consumer = consumer - } - - deinit { - self.consumer.iteratorDeinitialized() - } - - func next() async throws -> DataRow? { - try await self.consumer.next() - } - } - } -} - -final class AsyncStreamConsumer { - let lock = NIOLock() - - let lookupTable: [String: Int] - let columns: [RowDescription.Column] - private var state: StateMachine - - init( - lookupTable: [String: Int], - columns: [RowDescription.Column] - ) { - self.state = StateMachine() - - self.lookupTable = lookupTable - self.columns = columns - } - - func startCompleted(_ buffer: CircularBuffer, commandTag: String) { - self.lock.withLock { - self.state.finished(buffer, commandTag: commandTag) - } - } - - func startStreaming(_ buffer: CircularBuffer, upstream: PSQLRowStream) { - self.lock.withLock { - self.state.buffered(buffer, upstream: upstream) - } - } - - func startFailed(_ error: Error) { - self.lock.withLock { - self.state.failed(error) - } - } - - func receive(_ newRows: [DataRow]) { - let receiveAction = self.lock.withLock { - self.state.receive(newRows) - } - - switch receiveAction { - case .succeed(let continuation, let data, signalDemandTo: let source): - continuation.resume(returning: data) - source?.demand() - - case .none: - break - } - } - - func receive(completion result: Result) { - let completionAction = self.lock.withLock { - self.state.receive(completion: result) - } - - switch completionAction { - case .succeed(let continuation): - continuation.resume(returning: nil) - - case .fail(let continuation, let error): - continuation.resume(throwing: error) - - case .none: - break - } - } - - func sequenceDeinitialized() { - let action = self.lock.withLock { - self.state.sequenceDeinitialized() - } - - switch action { - case .cancelStream(let source): - source.cancel() - case .none: - break - } - } - - func makeAsyncIterator() -> PostgresRowSequence.AsyncIterator { - self.lock.withLock { - self.state.createAsyncIterator() - } - let iterator = PostgresRowSequence.AsyncIterator(consumer: self) - return iterator - } - - func iteratorDeinitialized() { - let action = self.lock.withLock { - self.state.iteratorDeinitialized() - } - - switch action { - case .cancelStream(let source): - source.cancel() - case .none: - break - } - } - - func next() async throws -> DataRow? { - self.lock.lock() - switch self.state.next() { - case .returnNil: - self.lock.unlock() - return nil - - case .returnRow(let data, signalDemandTo: let source): - self.lock.unlock() - source?.demand() - return data - - case .throwError(let error): - self.lock.unlock() - throw error - - case .hitSlowPath: - return try await withCheckedThrowingContinuation { continuation in - let slowPathAction = self.state.next(for: continuation) - self.lock.unlock() - switch slowPathAction { - case .signalDemand(let source): - source.demand() - case .none: - break - } - } - } - } - -} - -extension AsyncStreamConsumer { - private struct StateMachine { - private enum UpstreamState { - enum DemandState { - case canAskForMore - case waitingForMore(CheckedContinuation?) - } - - case initialized - /// The upstream has more data that can be received - case streaming(AdaptiveRowBuffer, PSQLRowStream, DemandState) - /// The upstream has finished, but the downstream has not consumed all events. - case finished(AdaptiveRowBuffer, String) - /// The upstream has failed, but the downstream has not consumed the error yet. - case failed(Error) - /// The upstream has failed or finished and the downstream has consumed all events. Final state. - case consumed - - /// A state used to prevent CoW allocations when modifying an internal struct in the - /// `.streaming` or `.finished` state. - case modifying - } - - private enum DownstreamState { - case sequenceCreated - case iteratorCreated - } - - private var upstreamState = UpstreamState.initialized - private var downstreamState = DownstreamState.sequenceCreated - - init() {} - - mutating func buffered(_ buffer: CircularBuffer, upstream: PSQLRowStream) { - switch self.upstreamState { - case .initialized: - let adaptive = AdaptiveRowBuffer(buffer) - self.upstreamState = .streaming(adaptive, upstream, buffer.isEmpty ? .waitingForMore(nil) : .canAskForMore) - - case .streaming, .finished, .failed, .consumed, .modifying: - preconditionFailure("Invalid upstream state: \(self.upstreamState)") - } - } - - mutating func finished(_ buffer: CircularBuffer, commandTag: String) { - switch self.upstreamState { - case .initialized: - let adaptive = AdaptiveRowBuffer(buffer) - self.upstreamState = .finished(adaptive, commandTag) - - case .streaming, .finished, .failed, .consumed, .modifying: - preconditionFailure("Invalid upstream state: \(self.upstreamState)") - } - } - - mutating func failed(_ error: Error) { - switch self.upstreamState { - case .initialized: - self.upstreamState = .failed(error) - - case .streaming, .finished, .failed, .consumed, .modifying: - preconditionFailure("Invalid upstream state: \(self.upstreamState)") - } - } - - mutating func createAsyncIterator() { - switch self.downstreamState { - case .sequenceCreated: - self.downstreamState = .iteratorCreated - case .iteratorCreated: - preconditionFailure("An iterator already exists") - } - } - - enum SequenceDeinitializedAction { - case cancelStream(PSQLRowStream) - case none - } - - mutating func sequenceDeinitialized() -> SequenceDeinitializedAction { - switch (self.downstreamState, self.upstreamState) { - case (.sequenceCreated, .initialized): - preconditionFailure("Invalid state: \(self.downstreamState), \(self.upstreamState)") - - case (.sequenceCreated, .streaming(_, let source, _)): - return .cancelStream(source) - - case (.sequenceCreated, .finished), - (.sequenceCreated, .consumed), - (.sequenceCreated, .failed): - return .none - - case (.iteratorCreated, _): - return .none - - case (_, .modifying): - preconditionFailure("Invalid state: \(self.downstreamState), \(self.upstreamState)") - } - } - - mutating func iteratorDeinitialized() -> SequenceDeinitializedAction { - switch (self.downstreamState, self.upstreamState) { - case (.sequenceCreated, _), - (.iteratorCreated, .initialized): - preconditionFailure("Invalid state: \(self.downstreamState), \(self.upstreamState)") - - case (.iteratorCreated, .streaming(_, let source, _)): - return .cancelStream(source) - - case (.iteratorCreated, .finished), - (.iteratorCreated, .consumed), - (.iteratorCreated, .failed): - return .none - - case (_, .modifying): - preconditionFailure("Invalid state: \(self.downstreamState), \(self.upstreamState)") - } - } - - enum NextFastPathAction { - case hitSlowPath - case throwError(Error) - case returnRow(DataRow, signalDemandTo: PSQLRowStream?) - case returnNil - } - - mutating func next() -> NextFastPathAction { - switch self.upstreamState { - case .initialized: - preconditionFailure() - - case .streaming(var buffer, let source, .canAskForMore): - self.upstreamState = .modifying - guard let (data, demand) = buffer.popFirst() else { - self.upstreamState = .streaming(buffer, source, .canAskForMore) - return .hitSlowPath - } - if demand { - self.upstreamState = .streaming(buffer, source, .waitingForMore(.none)) - return .returnRow(data, signalDemandTo: source) - } - self.upstreamState = .streaming(buffer, source, .canAskForMore) - return .returnRow(data, signalDemandTo: nil) - - case .streaming(var buffer, let source, .waitingForMore(.none)): - self.upstreamState = .modifying - guard let (data, _) = buffer.popFirst() else { - self.upstreamState = .streaming(buffer, source, .waitingForMore(.none)) - return .hitSlowPath - } - - self.upstreamState = .streaming(buffer, source, .waitingForMore(.none)) - return .returnRow(data, signalDemandTo: nil) - - case .streaming(_, _, .waitingForMore(.some)): - preconditionFailure() - - case .finished(var buffer, let commandTag): - self.upstreamState = .modifying - guard let (data, _) = buffer.popFirst() else { - self.upstreamState = .consumed - return .returnNil - } - - self.upstreamState = .finished(buffer, commandTag) - return .returnRow(data, signalDemandTo: nil) - - case .failed(let error): - self.upstreamState = .consumed - return .throwError(error) - - case .consumed: - return .returnNil - - case .modifying: - preconditionFailure("Invalid upstream state: \(self.upstreamState)") - } - } - - enum NextSlowPathAction { - case signalDemand(PSQLRowStream) - case none - } - - mutating func next(for continuation: CheckedContinuation) -> NextSlowPathAction { - switch self.upstreamState { - case .initialized: - preconditionFailure() - - case .streaming(let buffer, let source, .canAskForMore): - precondition(buffer.isEmpty) - self.upstreamState = .streaming(buffer, source, .waitingForMore(continuation)) - return .signalDemand(source) - - case .streaming(let buffer, let source, .waitingForMore(.none)): - precondition(buffer.isEmpty) - self.upstreamState = .streaming(buffer, source, .waitingForMore(continuation)) - return .none - - case .streaming(_, _, .waitingForMore(.some)), - .finished, - .failed, - .consumed: - preconditionFailure("Expected that state was already handled by fast path. Invalid upstream state: \(self.upstreamState)") - - case .modifying: - preconditionFailure("Invalid upstream state: \(self.upstreamState)") - } - } - - enum ReceiveAction { - case succeed(CheckedContinuation, DataRow, signalDemandTo: PSQLRowStream?) - case none - } - - mutating func receive(_ newRows: [DataRow]) -> ReceiveAction { - precondition(!newRows.isEmpty) - - switch self.upstreamState { - case .streaming(var buffer, let source, .waitingForMore(.some(let continuation))): - buffer.append(contentsOf: newRows) - let (first, demand) = buffer.removeFirst() - if demand { - self.upstreamState = .streaming(buffer, source, .waitingForMore(.none)) - return .succeed(continuation, first, signalDemandTo: source) - } - self.upstreamState = .streaming(buffer, source, .canAskForMore) - return .succeed(continuation, first, signalDemandTo: nil) - - case .streaming(var buffer, let source, .waitingForMore(.none)): - buffer.append(contentsOf: newRows) - self.upstreamState = .streaming(buffer, source, .canAskForMore) - return .none - - case .streaming(var buffer, let source, .canAskForMore): - buffer.append(contentsOf: newRows) - self.upstreamState = .streaming(buffer, source, .canAskForMore) - return .none - - case .initialized, .finished, .consumed: - preconditionFailure() - - case .failed: - return .none - - case .modifying: - preconditionFailure() - } - } - - enum CompletionResult { - case succeed(CheckedContinuation) - case fail(CheckedContinuation, Error) - case none - } - - mutating func receive(completion result: Result) -> CompletionResult { - switch result { - case .success(let commandTag): - return self.receiveEnd(commandTag: commandTag) - case .failure(let error): - return self.receiveError(error) - } - } - - private mutating func receiveEnd(commandTag: String) -> CompletionResult { - switch self.upstreamState { - case .streaming(let buffer, _, .waitingForMore(.some(let continuation))): - precondition(buffer.isEmpty) - self.upstreamState = .consumed - return .succeed(continuation) - - case .streaming(let buffer, _, .waitingForMore(.none)): - self.upstreamState = .finished(buffer, commandTag) - return .none - - case .streaming(let buffer, _, .canAskForMore): - self.upstreamState = .finished(buffer, commandTag) - return .none - - case .initialized, .finished, .consumed: - preconditionFailure("Invalid upstream state: \(self.upstreamState)") - - case .failed: - return .none - - case .modifying: - preconditionFailure() - } - } - - private mutating func receiveError(_ error: Error) -> CompletionResult { - switch self.upstreamState { - case .streaming(let buffer, _, .waitingForMore(.some(let continuation))): - precondition(buffer.isEmpty) - self.upstreamState = .consumed - return .fail(continuation, error) - - case .streaming(_, _, .waitingForMore(.none)): - self.upstreamState = .failed(error) - return .none - - case .streaming(_, _, .canAskForMore): - self.upstreamState = .failed(error) - return .none - - case .initialized, .finished, .consumed: - preconditionFailure("Invalid upstream state: \(self.upstreamState)") - - case .failed: - return .none - - case .modifying: - preconditionFailure() - } - } } } @@ -543,7 +67,7 @@ extension PostgresRowSequence { } } -struct AdaptiveRowBuffer { +struct AdaptiveRowBuffer: NIOAsyncSequenceProducerBackPressureStrategy { static let defaultBufferTarget = 256 static let defaultBufferMinimum = 1 static let defaultBufferMaximum = 16384 @@ -551,57 +75,41 @@ struct AdaptiveRowBuffer { let minimum: Int let maximum: Int - private var circularBuffer: CircularBuffer private var target: Int private var canShrink: Bool = false - var isEmpty: Bool { - self.circularBuffer.isEmpty - } - - init(minimum: Int, maximum: Int, target: Int, buffer: CircularBuffer) { + init(minimum: Int, maximum: Int, target: Int) { precondition(minimum <= target && target <= maximum) self.minimum = minimum self.maximum = maximum self.target = target - self.circularBuffer = buffer } - init(_ circularBuffer: CircularBuffer) { + init() { self.init( minimum: Self.defaultBufferMinimum, maximum: Self.defaultBufferMaximum, - target: Self.defaultBufferTarget, - buffer: circularBuffer + target: Self.defaultBufferTarget ) } - mutating func append(contentsOf newRows: Rows) where Rows.Element == DataRow { - self.circularBuffer.append(contentsOf: newRows) - if self.circularBuffer.count >= self.target, self.canShrink, self.target > self.minimum { + mutating func didYield(bufferDepth: Int) -> Bool { + if bufferDepth > self.target, self.canShrink, self.target > self.minimum { self.target &>>= 1 } self.canShrink = true - } - /// Returns the next row in the FIFO buffer and a `bool` signalling if new rows should be loaded. - mutating func removeFirst() -> (DataRow, Bool) { - let element = self.circularBuffer.removeFirst() + return false // bufferDepth < self.target + } + mutating func didConsume(bufferDepth: Int) -> Bool { // If the buffer is drained now, we should double our target size. - if self.circularBuffer.count == 0, self.target < self.maximum { + if bufferDepth == 0, self.target < self.maximum { self.target = self.target * 2 self.canShrink = false } - return (element, self.circularBuffer.count < self.target) - } - - mutating func popFirst() -> (DataRow, Bool)? { - guard !self.circularBuffer.isEmpty else { - return nil - } - return self.removeFirst() + return bufferDepth < self.target } } #endif diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index a8e20d76..54a8afc7 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -312,15 +312,16 @@ final class PostgresRowSequenceTests: XCTestCase { // received. let addDataRows1: [DataRow] = [[ByteBuffer(integer: Int64(0))]] stream.receive(addDataRows1) + XCTAssertEqual(dataSource.requestCount, 1) _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more XCTAssertEqual(dataSource.requestCount, 2) // if the buffer gets new rows so that it has equal or more than target (the target size // should be halved) - let addDataRows2: [DataRow] = [[ByteBuffer(integer: Int64(0))]] + let addDataRows2: [DataRow] = [[ByteBuffer(integer: Int64(0))], [ByteBuffer(integer: Int64(0))]] stream.receive(addDataRows2) // this should to target being halved. _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more - for _ in 0..<(AdaptiveRowBuffer.defaultBufferTarget/2 - 1) { + for _ in 0..<(AdaptiveRowBuffer.defaultBufferTarget / 2) { _ = try await rowIterator.next() // Remove all rows until we are back at target XCTAssertEqual(dataSource.requestCount, 2) } @@ -385,11 +386,12 @@ final class PostgresRowSequenceTests: XCTestCase { expectedRequestCount += 1 XCTAssertEqual(dataSource.requestCount, expectedRequestCount) - stream.receive([[ByteBuffer(integer: Int64(1))]]) + stream.receive([[ByteBuffer(integer: Int64(1))], [ByteBuffer(integer: Int64(1))]]) let newTarget = currentTarget / 2 + let toDrop = currentTarget + 1 - newTarget // consume all messages that are to much. - for _ in 0.. Date: Thu, 3 Nov 2022 11:43:36 +0100 Subject: [PATCH 123/292] Use NIOFoundationCompat for UUID <-> ByteBuffer (#319) --- Package.swift | 2 +- .../Data/PostgresData+String.swift | 2 +- .../PostgresNIO/Data/PostgresData+UUID.swift | 9 ++--- .../New/Data/UUID+PostgresCodable.swift | 35 ++----------------- 4 files changed, 6 insertions(+), 42 deletions(-) diff --git a/Package.swift b/Package.swift index 03ba1887..7e382068 100644 --- a/Package.swift +++ b/Package.swift @@ -14,7 +14,7 @@ let package = Package( ], dependencies: [ .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"), - .package(url: "https://github.com/apple/swift-nio.git", from: "2.42.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.44.0"), .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.13.1"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.22.1"), .package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "3.0.0"), diff --git a/Sources/PostgresNIO/Data/PostgresData+String.swift b/Sources/PostgresNIO/Data/PostgresData+String.swift index 79d9d428..66a08337 100644 --- a/Sources/PostgresNIO/Data/PostgresData+String.swift +++ b/Sources/PostgresNIO/Data/PostgresData+String.swift @@ -22,7 +22,7 @@ extension PostgresData { case .numeric: return self.numeric?.string case .uuid: - return value.readUUID()!.uuidString + return value.readUUIDBytes()!.uuidString case .timestamp, .timestamptz, .date: return self.date?.description case .money: diff --git a/Sources/PostgresNIO/Data/PostgresData+UUID.swift b/Sources/PostgresNIO/Data/PostgresData+UUID.swift index 148a9e66..f899b345 100644 --- a/Sources/PostgresNIO/Data/PostgresData+UUID.swift +++ b/Sources/PostgresNIO/Data/PostgresData+UUID.swift @@ -4,12 +4,7 @@ import NIOCore extension PostgresData { public init(uuid: UUID) { var buffer = ByteBufferAllocator().buffer(capacity: 16) - buffer.writeBytes([ - uuid.uuid.0, uuid.uuid.1, uuid.uuid.2, uuid.uuid.3, - uuid.uuid.4, uuid.uuid.5, uuid.uuid.6, uuid.uuid.7, - uuid.uuid.8, uuid.uuid.9, uuid.uuid.10, uuid.uuid.11, - uuid.uuid.12, uuid.uuid.13, uuid.uuid.14, uuid.uuid.15, - ]) + buffer.writeUUIDBytes(uuid) self.init(type: .uuid, formatCode: .binary, value: buffer) } @@ -22,7 +17,7 @@ extension PostgresData { case .binary: switch self.type { case .uuid: - return value.readUUID() + return value.readUUIDBytes() case .varchar, .text: return self.string.flatMap { UUID(uuidString: $0) } default: diff --git a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift index 3241ea01..be36395f 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift @@ -16,13 +16,7 @@ extension UUID: PostgresEncodable { into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext ) { - let uuid = self.uuid - byteBuffer.writeBytes([ - uuid.0, uuid.1, uuid.2, uuid.3, - uuid.4, uuid.5, uuid.6, uuid.7, - uuid.8, uuid.9, uuid.10, uuid.11, - uuid.12, uuid.13, uuid.14, uuid.15, - ]) + byteBuffer.writeUUIDBytes(self) } } @@ -36,7 +30,7 @@ extension UUID: PostgresDecodable { ) throws { switch (format, type) { case (.binary, .uuid): - guard let uuid = buffer.readUUID() else { + guard let uuid = buffer.readUUIDBytes() else { throw PostgresDecodingError.Code.failure } self = uuid @@ -60,28 +54,3 @@ extension UUID: PostgresDecodable { } extension UUID: PostgresCodable {} - -extension ByteBuffer { - @usableFromInline - mutating func readUUID() -> UUID? { - guard self.readableBytes >= MemoryLayout.size else { - return nil - } - - let value: UUID = self.getUUID(at: self.readerIndex)! /* must work as we have enough bytes */ - // should be MoveReaderIndex - self.moveReaderIndex(forwardBy: MemoryLayout.size) - return value - } - - func getUUID(at index: Int) -> UUID? { - var uuid: uuid_t = (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) - return self.viewBytes(at: index, length: MemoryLayout.size(ofValue: uuid)).map { bufferBytes in - withUnsafeMutableBytes(of: &uuid) { target in - precondition(target.count <= bufferBytes.count) - target.copyBytes(from: bufferBytes) - } - return UUID(uuid: uuid) - } - } -} From 607279d6893dd0332263489cc725255bfcfe520f Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 10 Nov 2022 14:54:15 +0100 Subject: [PATCH 124/292] PostgresCodable should be a typealias (#321) --- .../PostgresNIO/New/Data/Bool+PostgresCodable.swift | 2 -- .../PostgresNIO/New/Data/Bytes+PostgresCodable.swift | 4 ---- .../PostgresNIO/New/Data/Date+PostgresCodable.swift | 2 -- .../PostgresNIO/New/Data/Decimal+PostgresCodable.swift | 2 -- .../PostgresNIO/New/Data/Float+PostgresCodable.swift | 4 ---- Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift | 10 ---------- .../PostgresNIO/New/Data/JSON+PostgresCodable.swift | 2 -- .../New/Data/RawRepresentable+PostgresCodable.swift | 2 -- .../PostgresNIO/New/Data/String+PostgresCodable.swift | 2 -- .../PostgresNIO/New/Data/UUID+PostgresCodable.swift | 2 -- Sources/PostgresNIO/New/PostgresCodable.swift | 2 +- 11 files changed, 1 insertion(+), 33 deletions(-) diff --git a/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift index 1aa264b8..3148a726 100644 --- a/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift @@ -60,5 +60,3 @@ extension Bool: PostgresEncodable { byteBuffer.writeInteger(self ? 1 : 0, as: UInt8.self) } } - -extension Bool: PostgresCodable {} diff --git a/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift index edf79462..fcd70472 100644 --- a/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift @@ -51,8 +51,6 @@ extension ByteBuffer: PostgresDecodable { } } -extension ByteBuffer: PostgresCodable {} - extension Data: PostgresEncodable { public static var psqlType: PostgresDataType { .bytea @@ -82,5 +80,3 @@ extension Data: PostgresDecodable { self = buffer.readData(length: buffer.readableBytes, byteTransferStrategy: .automatic)! } } - -extension Data: PostgresCodable {} diff --git a/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift index e32ecb10..b915fcb3 100644 --- a/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift @@ -57,5 +57,3 @@ extension Date: PostgresDecodable { } } } - -extension Date: PostgresCodable {} diff --git a/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift index 4ab96386..f634d4ae 100644 --- a/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Decimal+PostgresCodable.swift @@ -47,5 +47,3 @@ extension Decimal: PostgresDecodable { } } } - -extension Decimal: PostgresCodable {} diff --git a/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift index 7943c152..70636772 100644 --- a/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift @@ -48,8 +48,6 @@ extension Float: PostgresDecodable { } } -extension Float: PostgresCodable {} - extension Double: PostgresEncodable { public static var psqlType: PostgresDataType { .float8 @@ -97,5 +95,3 @@ extension Double: PostgresDecodable { } } } - -extension Double: PostgresCodable {} diff --git a/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift index e4a2492d..d8335ff1 100644 --- a/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift @@ -41,8 +41,6 @@ extension UInt8: PostgresDecodable { } } -extension UInt8: PostgresCodable {} - // MARK: Int16 extension Int16: PostgresEncodable { @@ -88,8 +86,6 @@ extension Int16: PostgresDecodable { } } -extension Int16: PostgresCodable {} - // MARK: Int32 extension Int32: PostgresEncodable { @@ -140,8 +136,6 @@ extension Int32: PostgresDecodable { } } -extension Int32: PostgresCodable {} - // MARK: Int64 extension Int64: PostgresEncodable { @@ -197,8 +191,6 @@ extension Int64: PostgresDecodable { } } -extension Int64: PostgresCodable {} - // MARK: Int extension Int: PostgresEncodable { @@ -260,5 +252,3 @@ extension Int: PostgresDecodable { } } } - -extension Int: PostgresCodable {} diff --git a/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift index d5696bf2..539cd9e2 100644 --- a/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift @@ -45,5 +45,3 @@ extension PostgresDecodable where Self: Decodable { } } } - -extension PostgresCodable where Self: Codable {} diff --git a/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift index 4c0195e3..4d6c20c4 100644 --- a/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift @@ -33,5 +33,3 @@ extension PostgresDecodable where Self: RawRepresentable, RawValue: PostgresDeco self = selfValue } } - -extension PostgresCodable where Self: RawRepresentable, RawValue: PostgresCodable, RawValue._DecodableType == RawValue {} diff --git a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift index 8efb8155..e262b343 100644 --- a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift @@ -45,5 +45,3 @@ extension String: PostgresDecodable { } } } - -extension String: PostgresCodable {} diff --git a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift index be36395f..cb65c5ce 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift @@ -52,5 +52,3 @@ extension UUID: PostgresDecodable { } } } - -extension UUID: PostgresCodable {} diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift index 6a40b4bf..180e9bbf 100644 --- a/Sources/PostgresNIO/New/PostgresCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -68,7 +68,7 @@ extension PostgresDecodable { } /// A type that can be encoded into and decoded from a postgres binary format -protocol PostgresCodable: PostgresEncodable, PostgresDecodable {} +typealias PostgresCodable = PostgresEncodable & PostgresDecodable extension PostgresEncodable { @inlinable From 3a16650354aff072fc6ebfe5345a780f153a224b Mon Sep 17 00:00:00 2001 From: Tim Condon <0xTim@users.noreply.github.com> Date: Thu, 17 Nov 2022 18:31:34 +0000 Subject: [PATCH 125/292] Update SPI info to point to our hosted docs (#325) --- .spi.yml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.spi.yml b/.spi.yml index 76fd1534..177f9874 100644 --- a/.spi.yml +++ b/.spi.yml @@ -1,6 +1,4 @@ version: 1 -builder: - configs: - - documentation_targets: - - PostgresNIO +external_links: + documentation: "https://api.vapor.codes/postgres-nio/documentation/postgresnio/" From b5cca7227d328cf59037426641c5bf3053f366fe Mon Sep 17 00:00:00 2001 From: ehpi <16744346+ehpi@users.noreply.github.com> Date: Mon, 2 Jan 2023 10:57:45 +0100 Subject: [PATCH 126/292] PostgresData: Fix en- and decoding of NULL values in arrays (#324) --- .../PostgresNIO/Data/PostgresData+Array.swift | 12 ++++---- Tests/IntegrationTests/PostgresNIOTests.swift | 29 +++++++++++++++++++ 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/Sources/PostgresNIO/Data/PostgresData+Array.swift b/Sources/PostgresNIO/Data/PostgresData+Array.swift index bbb420bc..d0c1c6f4 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Array.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Array.swift @@ -13,7 +13,7 @@ extension PostgresData { var buffer = ByteBufferAllocator().buffer(capacity: 0) // 0 if empty, 1 if not buffer.writeInteger(array.isEmpty ? 0 : 1, as: UInt32.self) - // b + // b - this gets ignored by psql buffer.writeInteger(0, as: UInt32.self) // array element type buffer.writeInteger(elementType.rawValue) @@ -30,7 +30,7 @@ extension PostgresData { buffer.writeInteger(numericCast(value.readableBytes), as: UInt32.self) buffer.writeBuffer(&value) } else { - buffer.writeInteger(0, as: UInt32.self) + buffer.writeInteger(-1, as: Int32.self) } } } @@ -77,10 +77,10 @@ extension PostgresData { guard let isNotEmpty = value.readInteger(as: UInt32.self) else { return nil } - guard let b = value.readInteger(as: UInt32.self) else { + // b + guard let _ = value.readInteger(as: UInt32.self) else { return nil } - assert(b == 0, "Array b field did not equal zero") guard let type = value.readInteger(as: PostgresDataType.self) else { return nil } @@ -99,9 +99,9 @@ extension PostgresData { var array: [PostgresData] = [] while - let itemLength = value.readInteger(as: UInt32.self), - let itemValue = value.readSlice(length: numericCast(itemLength)) + let itemLength = value.readInteger(as: Int32.self) { + let itemValue = itemLength == -1 ? nil : value.readSlice(length: numericCast(itemLength)) let data = PostgresData( type: type, typeModifier: nil, diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index b455fdef..a56c4551 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -555,6 +555,19 @@ final class PostgresNIOTests: XCTestCase { let row = rows?.first?.makeRandomAccess() XCTAssertEqual(row?[data: "array"].array(of: Int.self), []) } + + func testOptionalIntegerArrayParse() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query(""" + select + '{1, 2, NULL, 4}'::int8[] as array + """).wait()) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "array"].array(of: Int?.self), [1, 2, nil, 4]) + } func testNullIntegerArrayParse() { var conn: PostgresConnection? @@ -599,6 +612,22 @@ final class PostgresNIOTests: XCTestCase { XCTAssertEqual(row?[data: "array"].array(of: Int.self), []) } + func testOptionalIntegerArraySerialize() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query(""" + select + $1::int8[] as array + """, [ + PostgresData(array: [1, nil, 3] as [Int64?]) + ]).wait()) + XCTAssertEqual(rows?.count, 1) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "array"].array(of: Int64?.self), [1, nil, 3]) + } + // https://github.com/vapor/postgres-nio/issues/143 func testEmptyStringFromNonNullColumn() { var conn: PostgresConnection? From ffb5121bc9e8c16d080ac4daf62d049dab261803 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 21 Jan 2023 15:46:17 +0100 Subject: [PATCH 127/292] Add new protocol PostgresNonThrowingEncodable (#322) --- .../New/Data/Array+PostgresCodable.swift | 37 +++++++++++++++++++ .../New/Data/Bool+PostgresCodable.swift | 2 +- .../New/Data/Bytes+PostgresCodable.swift | 6 ++- .../New/Data/Date+PostgresCodable.swift | 2 +- .../New/Data/Float+PostgresCodable.swift | 4 +- .../New/Data/Int+PostgresCodable.swift | 10 ++--- .../New/Data/String+PostgresCodable.swift | 2 +- .../New/Data/UUID+PostgresCodable.swift | 2 +- Sources/PostgresNIO/New/PostgresCodable.swift | 32 ++++++++++++++++ Sources/PostgresNIO/New/PostgresQuery.swift | 27 ++++++++++++++ Tests/IntegrationTests/AsyncTests.swift | 2 +- .../ExtendedQueryStateMachineTests.swift | 4 +- .../New/Data/Array+PSQLCodableTests.swift | 4 +- .../New/Messages/BindTests.swift | 4 +- .../New/PostgresQueryTests.swift | 16 +++----- 15 files changed, 124 insertions(+), 30 deletions(-) diff --git a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift index c3bf3eb4..2c57b605 100644 --- a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift @@ -124,6 +124,43 @@ extension Array: PostgresEncodable where Element: PostgresArrayEncodable { } } +extension Array: PostgresNonThrowingEncodable where Element: PostgresArrayEncodable & PostgresNonThrowingEncodable { + public static var psqlType: PostgresDataType { + Element.psqlArrayType + } + + public static var psqlFormat: PostgresFormat { + .binary + } + + @inlinable + public func encode( + into buffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + // 0 if empty, 1 if not + buffer.writeInteger(self.isEmpty ? 0 : 1, as: UInt32.self) + // b + buffer.writeInteger(0, as: Int32.self) + // array element type + buffer.writeInteger(Element.psqlType.rawValue) + + // continue if the array is not empty + guard !self.isEmpty else { + return + } + + // length of array + buffer.writeInteger(numericCast(self.count), as: Int32.self) + // dimensions + buffer.writeInteger(1, as: Int32.self) + + self.forEach { element in + element.encodeRaw(into: &buffer, context: context) + } + } +} + extension Array: PostgresDecodable where Element: PostgresArrayDecodable, Element == Element._DecodableType { public init( diff --git a/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift index 3148a726..515d167a 100644 --- a/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bool+PostgresCodable.swift @@ -43,7 +43,7 @@ extension Bool: PostgresDecodable { } } -extension Bool: PostgresEncodable { +extension Bool: PostgresNonThrowingEncodable { public static var psqlType: PostgresDataType { .bool } diff --git a/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift index fcd70472..f6544df0 100644 --- a/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bytes+PostgresCodable.swift @@ -6,7 +6,7 @@ extension PostgresEncodable where Self: Sequence, Self.Element == UInt8 { public static var psqlType: PostgresDataType { .bytea } - + public static var psqlFormat: PostgresFormat { .binary } @@ -20,7 +20,9 @@ extension PostgresEncodable where Self: Sequence, Self.Element == UInt8 { } } -extension ByteBuffer: PostgresEncodable { +extension PostgresNonThrowingEncodable where Self: Sequence, Self.Element == UInt8 {} + +extension ByteBuffer: PostgresNonThrowingEncodable { public static var psqlType: PostgresDataType { .bytea } diff --git a/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift index b915fcb3..31d8d749 100644 --- a/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Date+PostgresCodable.swift @@ -1,7 +1,7 @@ import NIOCore import struct Foundation.Date -extension Date: PostgresEncodable { +extension Date: PostgresNonThrowingEncodable { public static var psqlType: PostgresDataType { .timestamptz } diff --git a/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift index 70636772..8b5e4472 100644 --- a/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Float+PostgresCodable.swift @@ -1,6 +1,6 @@ import NIOCore -extension Float: PostgresEncodable { +extension Float: PostgresNonThrowingEncodable { public static var psqlType: PostgresDataType { .float4 } @@ -48,7 +48,7 @@ extension Float: PostgresDecodable { } } -extension Double: PostgresEncodable { +extension Double: PostgresNonThrowingEncodable { public static var psqlType: PostgresDataType { .float8 } diff --git a/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift index d8335ff1..c2f3b339 100644 --- a/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Int+PostgresCodable.swift @@ -2,7 +2,7 @@ import NIOCore // MARK: UInt8 -extension UInt8: PostgresEncodable { +extension UInt8: PostgresNonThrowingEncodable { public static var psqlType: PostgresDataType { .char } @@ -43,7 +43,7 @@ extension UInt8: PostgresDecodable { // MARK: Int16 -extension Int16: PostgresEncodable { +extension Int16: PostgresNonThrowingEncodable { public static var psqlType: PostgresDataType { .int2 } @@ -88,7 +88,7 @@ extension Int16: PostgresDecodable { // MARK: Int32 -extension Int32: PostgresEncodable { +extension Int32: PostgresNonThrowingEncodable { public static var psqlType: PostgresDataType { .int4 } @@ -138,7 +138,7 @@ extension Int32: PostgresDecodable { // MARK: Int64 -extension Int64: PostgresEncodable { +extension Int64: PostgresNonThrowingEncodable { public static var psqlType: PostgresDataType { .int8 } @@ -193,7 +193,7 @@ extension Int64: PostgresDecodable { // MARK: Int -extension Int: PostgresEncodable { +extension Int: PostgresNonThrowingEncodable { public static var psqlType: PostgresDataType { switch MemoryLayout.size { case 4: diff --git a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift index e262b343..f8e93e94 100644 --- a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift @@ -1,7 +1,7 @@ import NIOCore import struct Foundation.UUID -extension String: PostgresEncodable { +extension String: PostgresNonThrowingEncodable { public static var psqlType: PostgresDataType { .text } diff --git a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift index cb65c5ce..632d5d93 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift @@ -2,7 +2,7 @@ import NIOCore import struct Foundation.UUID import typealias Foundation.uuid_t -extension UUID: PostgresEncodable { +extension UUID: PostgresNonThrowingEncodable { public static var psqlType: PostgresDataType { .uuid } diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift index 180e9bbf..bd4e7f91 100644 --- a/Sources/PostgresNIO/New/PostgresCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -3,6 +3,8 @@ import Foundation /// A type that can encode itself to a postgres wire binary representation. public protocol PostgresEncodable { + // TODO: Rename to `PostgresThrowingEncodable` with next major release + /// identifies the data type that we will encode into `byteBuffer` in `encode` static var psqlType: PostgresDataType { get } @@ -14,6 +16,16 @@ public protocol PostgresEncodable { func encode(into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext) throws } +/// A type that can encode itself to a postgres wire binary representation. It enforces that the +/// ``PostgresEncodable/encode(into:context:)`` does not throw. This allows users +/// to create ``PostgresQuery``s using the `ExpressibleByStringInterpolation` without +/// having to spell `try`. +public protocol PostgresNonThrowingEncodable: PostgresEncodable { + // TODO: Rename to `PostgresEncodable` with next major release + + func encode(into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext) +} + /// A type that can decode itself from a postgres wire binary representation. /// /// If you want to conform a type to PostgresDecodable you must implement the decode method. @@ -90,6 +102,26 @@ extension PostgresEncodable { } } +extension PostgresNonThrowingEncodable { + @inlinable + func encodeRaw( + into buffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + // The length of the parameter value, in bytes (this count does not include + // itself). Can be zero. + let lengthIndex = buffer.writerIndex + buffer.writeInteger(0, as: Int32.self) + let startIndex = buffer.writerIndex + // The value of the parameter, in the format indicated by the associated format + // code. n is the above length. + self.encode(into: &buffer, context: context) + + // overwrite the empty length, with the real value + buffer.setInteger(numericCast(buffer.writerIndex - startIndex), at: lengthIndex, as: Int32.self) + } +} + /// A context that is passed to Swift objects that are encoded into the Postgres wire format. Used /// to pass further information to the encoding method. public struct PostgresEncodingContext { diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 276d969f..94072ae3 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -59,6 +59,24 @@ extension PostgresQuery { self.sql.append(contentsOf: "$\(self.binds.count)") } + @inlinable + public mutating func appendInterpolation(_ value: Value) { + self.binds.append(value, context: .default) + self.sql.append(contentsOf: "$\(self.binds.count)") + } + + @inlinable + public mutating func appendInterpolation(_ value: Optional) { + switch value { + case .none: + self.binds.appendNull() + case .some(let value): + self.binds.append(value, context: .default) + } + + self.sql.append(contentsOf: "$\(self.binds.count)") + } + @inlinable public mutating func appendInterpolation( _ value: Value, @@ -139,6 +157,15 @@ public struct PostgresBindings: Hashable { self.metadata.append(.init(value: value)) } + @inlinable + public mutating func append( + _ value: Value, + context: PostgresEncodingContext + ) { + value.encodeRaw(into: &self.bytes, context: context) + self.metadata.append(.init(value: value)) + } + mutating func append(_ postgresData: PostgresData) { switch postgresData.value { case .none: diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 9d43397f..00896a91 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -90,7 +90,7 @@ final class AsyncPostgresConnectionTests: XCTestCase { var binds = PostgresBindings(capacity: Int(UInt16.max)) for _ in (0.. Date: Sat, 4 Feb 2023 00:58:05 +0000 Subject: [PATCH 128/292] Migrate API docs to new workflow --- .github/workflows/api-docs.yml | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/.github/workflows/api-docs.yml b/.github/workflows/api-docs.yml index d521498e..5d8c32dd 100644 --- a/.github/workflows/api-docs.yml +++ b/.github/workflows/api-docs.yml @@ -1,18 +1,12 @@ name: deploy-api-docs on: - push: - branches: - - main + push: + branches: + - main jobs: - deploy: - name: api.vapor.codes - runs-on: ubuntu-latest - steps: - - name: Deploy api-docs - uses: appleboy/ssh-action@master - with: - host: vapor.codes - username: vapor - key: ${{ secrets.VAPOR_CODES_SSH_KEY }} - script: ./github-actions/deploy-api-docs.sh + build-and-deploy: + uses: vapor/api-docs/.github/workflows/build-and-deploy-docs-workflow.yml@main + with: + package_name: postgres-nio + modules: PostgresNIO \ No newline at end of file From 7cfd33c73ecc186c49fc476aad82bebb3d5529be Mon Sep 17 00:00:00 2001 From: Tim Condon <0xTim@users.noreply.github.com> Date: Sun, 5 Feb 2023 19:57:01 +0000 Subject: [PATCH 129/292] Use newer docc github action --- .github/workflows/api-docs.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/api-docs.yml b/.github/workflows/api-docs.yml index 5d8c32dd..29e73a82 100644 --- a/.github/workflows/api-docs.yml +++ b/.github/workflows/api-docs.yml @@ -6,7 +6,8 @@ on: jobs: build-and-deploy: - uses: vapor/api-docs/.github/workflows/build-and-deploy-docs-workflow.yml@main + uses: vapor/api-docs/.github/workflows/build-and-deploy-docs-workflow.yml@8d28281fe89fd836116d59c7fe217df651ebf41a + secrets: inherit with: package_name: postgres-nio - modules: PostgresNIO \ No newline at end of file + modules: PostgresNIO From 521b6b4ca8027ad6ee6f5474b6ecaf23dad84762 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 14 Feb 2023 11:03:38 +0100 Subject: [PATCH 130/292] `PostgresQuery` and `PostgresBindings` should be `Sendable` (#328) --- Sources/PostgresNIO/New/PostgresQuery.swift | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 94072ae3..bbacf5c3 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -177,3 +177,9 @@ public struct PostgresBindings: Hashable { self.metadata.append(.init(dataType: postgresData.type, format: .binary)) } } + +#if swift(>=5.6) +extension PostgresQuery: Sendable {} +extension PostgresBindings: Sendable {} +extension PostgresBindings.Metadata: Sendable {} +#endif From 7a816db082008b7e4c0f1000ae0e827ac5d970e5 Mon Sep 17 00:00:00 2001 From: Tim Condon <0xTim@users.noreply.github.com> Date: Mon, 20 Feb 2023 08:11:03 +0000 Subject: [PATCH 131/292] Update API Docs Workflow (#329) --- .github/workflows/api-docs.yml | 3 ++- .github/workflows/test.yml | 10 ++++++++++ .spi.yml | 2 +- .../PostgresNIO/New/PSQLFrontendMessageEncoder.swift | 1 + .../New/PostgresBackendMessageDecoder.swift | 2 ++ Sources/PostgresNIO/New/PostgresQuery.swift | 2 ++ Sources/PostgresNIO/Utilities/Exports.swift | 4 ++++ .../PostgresNIO/Utilities/PostgresJSONDecoder.swift | 1 + .../PostgresNIO/Utilities/PostgresJSONEncoder.swift | 1 + Tests/IntegrationTests/AsyncTests.swift | 2 ++ Tests/IntegrationTests/PostgresNIOTests.swift | 1 + .../Message/PostgresMessageDecoderTests.swift | 1 + .../New/Extensions/PSQLFrontendMessageDecoder.swift | 1 + Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift | 2 ++ Tests/PostgresNIOTests/New/PostgresCellTests.swift | 1 + Tests/PostgresNIOTests/New/PostgresCodableTests.swift | 1 + Tests/PostgresNIOTests/New/PostgresErrorTests.swift | 1 + Tests/PostgresNIOTests/New/PostgresQueryTests.swift | 1 + .../New/PostgresRowSequenceTests.swift | 2 ++ Tests/PostgresNIOTests/New/PostgresRowTests.swift | 1 + 20 files changed, 38 insertions(+), 2 deletions(-) diff --git a/.github/workflows/api-docs.yml b/.github/workflows/api-docs.yml index 29e73a82..80291c6f 100644 --- a/.github/workflows/api-docs.yml +++ b/.github/workflows/api-docs.yml @@ -6,8 +6,9 @@ on: jobs: build-and-deploy: - uses: vapor/api-docs/.github/workflows/build-and-deploy-docs-workflow.yml@8d28281fe89fd836116d59c7fe217df651ebf41a + uses: vapor/api-docs/.github/workflows/build-and-deploy-docs-workflow.yml@main secrets: inherit with: package_name: postgres-nio modules: PostgresNIO + pathsToInvalidate: /postgresnio diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c1d82648..e83acc54 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -164,3 +164,13 @@ jobs: - name: API breaking changes run: | swift package diagnose-api-breaking-changes origin/main + test-exports: + name: Test exports + runs-on: ubuntu-latest + steps: + - name: Check out package + uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Build + run: swift build -Xswiftc -DBUILDING_DOCC \ No newline at end of file diff --git a/.spi.yml b/.spi.yml index 177f9874..690e4f2a 100644 --- a/.spi.yml +++ b/.spi.yml @@ -1,4 +1,4 @@ version: 1 external_links: - documentation: "https://api.vapor.codes/postgres-nio/documentation/postgresnio/" + documentation: "https://api.vapor.codes/postgresnio/documentation/postgresnio/" diff --git a/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift index 8447c683..24155d84 100644 --- a/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift @@ -1,3 +1,4 @@ +import NIOCore struct PSQLFrontendMessageEncoder: MessageToByteEncoder { typealias OutboundIn = PostgresFrontendMessage diff --git a/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift b/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift index e8487fb6..076daa19 100644 --- a/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift +++ b/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift @@ -1,3 +1,5 @@ +import NIOCore + struct PostgresBackendMessageDecoder: NIOSingleStepByteToMessageDecoder { typealias InboundOut = PostgresBackendMessage diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index bbacf5c3..9aa93d3b 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -1,3 +1,5 @@ +import NIOCore + /// A Postgres SQL query, that can be executed on a Postgres server. Contains the raw sql string and bindings. public struct PostgresQuery: Hashable { /// The query string diff --git a/Sources/PostgresNIO/Utilities/Exports.swift b/Sources/PostgresNIO/Utilities/Exports.swift index 4224d53f..1c020411 100644 --- a/Sources/PostgresNIO/Utilities/Exports.swift +++ b/Sources/PostgresNIO/Utilities/Exports.swift @@ -1,4 +1,8 @@ +#if !BUILDING_DOCC + // TODO: Remove this with the next major release! @_exported import NIO @_exported import NIOSSL @_exported import struct Logging.Logger + +#endif diff --git a/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift b/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift index 5a87a182..fb7b4e8d 100644 --- a/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift +++ b/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift @@ -1,6 +1,7 @@ import class Foundation.JSONDecoder import struct Foundation.Data import NIOFoundationCompat +import NIOCore /// A protocol that mimicks the Foundation `JSONDecoder.decode(_:from:)` function. /// Conform a non-Foundation JSON decoder to this protocol if you want PostgresNIO to be diff --git a/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift b/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift index 3cabcf1d..735e4b14 100644 --- a/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift +++ b/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift @@ -1,5 +1,6 @@ import Foundation import NIOFoundationCompat +import NIOCore /// A protocol that mimicks the Foundation `JSONEncoder.encode(_:)` function. /// Conform a non-Foundation JSON encoder to this protocol if you want PostgresNIO to be diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 00896a91..b1a72e5f 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -4,6 +4,8 @@ import PostgresNIO #if canImport(Network) import NIOTransportServices #endif +import NIOPosix +import NIOCore #if canImport(_Concurrency) final class AsyncPostgresConnectionTests: XCTestCase { diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index a56c4551..8c84e280 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -4,6 +4,7 @@ import XCTest import NIOCore import NIOPosix import NIOTestUtils +import NIOSSL final class PostgresNIOTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/Message/PostgresMessageDecoderTests.swift b/Tests/PostgresNIOTests/Message/PostgresMessageDecoderTests.swift index d4557a55..bbd022db 100644 --- a/Tests/PostgresNIOTests/Message/PostgresMessageDecoderTests.swift +++ b/Tests/PostgresNIOTests/Message/PostgresMessageDecoderTests.swift @@ -1,6 +1,7 @@ import PostgresNIO import XCTest import NIOTestUtils +import NIOCore class PostgresMessageDecoderTests: XCTestCase { @available(*, deprecated, message: "Tests deprecated API") diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index 047a2968..91471d86 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -1,4 +1,5 @@ @testable import PostgresNIO +import NIOCore struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { typealias InboundOut = PostgresFrontendMessage diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift index 5ca43591..f27ff060 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -2,6 +2,8 @@ import NIOCore import Logging import XCTest @testable import PostgresNIO +import NIOCore +import NIOEmbedded class PSQLRowStreamTests: XCTestCase { func testEmptyStream() { diff --git a/Tests/PostgresNIOTests/New/PostgresCellTests.swift b/Tests/PostgresNIOTests/New/PostgresCellTests.swift index df7cbfd9..7df5ac9f 100644 --- a/Tests/PostgresNIOTests/New/PostgresCellTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresCellTests.swift @@ -1,5 +1,6 @@ @testable import PostgresNIO import XCTest +import NIOCore final class PostgresCellTests: XCTestCase { func testDecodingANonOptionalString() { diff --git a/Tests/PostgresNIOTests/New/PostgresCodableTests.swift b/Tests/PostgresNIOTests/New/PostgresCodableTests.swift index ef76e22a..c1ef041e 100644 --- a/Tests/PostgresNIOTests/New/PostgresCodableTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresCodableTests.swift @@ -1,5 +1,6 @@ import XCTest @testable import PostgresNIO +import NIOCore final class PostgresCodableTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/PostgresErrorTests.swift b/Tests/PostgresNIOTests/New/PostgresErrorTests.swift index a3f44980..b1b78ff9 100644 --- a/Tests/PostgresNIOTests/New/PostgresErrorTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresErrorTests.swift @@ -1,5 +1,6 @@ @testable import PostgresNIO import XCTest +import NIOCore final class PostgresDecodingErrorTests: XCTestCase { func testPostgresDecodingErrorEquality() { diff --git a/Tests/PostgresNIOTests/New/PostgresQueryTests.swift b/Tests/PostgresNIOTests/New/PostgresQueryTests.swift index 832db148..926541f0 100644 --- a/Tests/PostgresNIOTests/New/PostgresQueryTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresQueryTests.swift @@ -1,5 +1,6 @@ @testable import PostgresNIO import XCTest +import NIOCore final class PostgresQueryTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index 54a8afc7..5cd69662 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -3,6 +3,8 @@ import NIOEmbedded import Dispatch import XCTest @testable import PostgresNIO +import NIOCore +import Logging #if canImport(_Concurrency) final class PostgresRowSequenceTests: XCTestCase { diff --git a/Tests/PostgresNIOTests/New/PostgresRowTests.swift b/Tests/PostgresNIOTests/New/PostgresRowTests.swift index 7a67823b..c84b9baa 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowTests.swift @@ -1,5 +1,6 @@ import XCTest @testable import PostgresNIO +import NIOCore final class PostgresRowTests: XCTestCase { From 5d93f3e05f0493441ad46f6d7a76109ff685329c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9my=20Touzy?= Date: Sun, 12 Mar 2023 14:06:32 +0100 Subject: [PATCH 132/292] Make Decodable autoconformance to PostgresDecodable public (#331) --- Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift index 539cd9e2..e469f0e5 100644 --- a/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/JSON+PostgresCodable.swift @@ -26,7 +26,7 @@ extension PostgresEncodable where Self: Encodable { } extension PostgresDecodable where Self: Decodable { - init( + public init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, From 4097b2f7a164f6c9572303f1d53032a92711480a Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 22 Mar 2023 09:27:13 +0100 Subject: [PATCH 133/292] Make `PostgresBindings.append(PostgresData)` public (#332) --- Sources/PostgresNIO/New/PostgresQuery.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 9aa93d3b..6f224895 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -168,7 +168,7 @@ public struct PostgresBindings: Hashable { self.metadata.append(.init(value: value)) } - mutating func append(_ postgresData: PostgresData) { + public mutating func append(_ postgresData: PostgresData) { switch postgresData.value { case .none: self.bytes.writeInteger(-1, as: Int32.self) From cf09800cfc59a7fce17cbf9c699c43ed5a405ea9 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Fri, 14 Apr 2023 04:55:14 -0500 Subject: [PATCH 134/292] Update CI in preparation for bumping to 5.6 min version (#337) Update CI in preparation for bumping to 5.6 min version. --- .github/workflows/test.yml | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e83acc54..5945d014 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,4 +1,7 @@ name: CI +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true on: push: branches: @@ -13,9 +16,10 @@ jobs: fail-fast: false matrix: container: - - swift:5.5-bionic - swift:5.6-focal - swift:5.7-jammy + - swift:5.8-jammy + - swiftlang/swift:nightly-5.9-jammy - swiftlang/swift:nightly-main-jammy container: ${{ matrix.container }} runs-on: ubuntu-latest @@ -27,7 +31,7 @@ jobs: - name: Run unit tests with code coverage and Thread Sanitizer run: swift test --enable-test-discovery --filter=^PostgresNIOTests --sanitize=thread --enable-code-coverage - name: Submit coverage report to Codecov.io - if: "!contains(matrix.container, 'nightly')" + if: ${{ !contains(matrix.container, '5.8') }} uses: vapor/swift-codecov-action@v0.2 with: cc_flags: 'unittests' @@ -52,7 +56,8 @@ jobs: dbauth: md5 - dbimage: postgres:11 dbauth: trust - container: swift:5.7-jammy + container: + image: swift:5.8-jammy runs-on: ubuntu-latest env: LOG_LEVEL: debug @@ -114,14 +119,13 @@ jobs: fail-fast: false matrix: dbimage: - # Only test the lastest version on macOS, let Linux do the rest + # Only test one version on macOS, let Linux do the rest - postgresql@14 dbauth: # Only test one auth method on macOS, Linux tests will cover the others - scram-sha-256 xcode: - latest-stable - #- latest runs-on: macos-12 env: LOG_LEVEL: debug @@ -145,14 +149,12 @@ jobs: - name: Checkout code uses: actions/checkout@v3 - name: Run all tests - run: | - swift test --enable-test-discovery -Xlinker -rpath \ - -Xlinker $(xcode-select -p)/Toolchains/XcodeDefault.xctoolchain/usr/lib/swift-5.5/macosx + run: swift test api-breakage: if: github.event_name == 'pull_request' runs-on: ubuntu-latest - container: swift:5.7-jammy + container: swift:5.8-jammy steps: - name: Checkout uses: actions/checkout@v3 @@ -162,15 +164,16 @@ jobs: - name: Mark the workspace as safe run: git config --global --add safe.directory ${GITHUB_WORKSPACE} - name: API breaking changes - run: | - swift package diagnose-api-breaking-changes origin/main + run: swift package diagnose-api-breaking-changes origin/main + test-exports: name: Test exports runs-on: ubuntu-latest + container: swift:5.8-jammy steps: - name: Check out package uses: actions/checkout@v3 with: fetch-depth: 0 - name: Build - run: swift build -Xswiftc -DBUILDING_DOCC \ No newline at end of file + run: swift build -Xswiftc -DBUILDING_DOCC From cf62abcf023cef8fcab3ef786c6f4a2fdef0d936 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 14 Apr 2023 12:10:28 +0200 Subject: [PATCH 135/292] Drop support for Swift 5.5 (#336) --- Package.swift | 2 +- .../Connection/PostgresConnection.swift | 10 ++----- Sources/PostgresNIO/Data/PostgresData.swift | 6 +--- .../PostgresNIO/Data/PostgresDataType.swift | 13 ++------- Sources/PostgresNIO/Data/PostgresRow.swift | 10 ++----- .../PostgresNIO/New/Messages/DataRow.swift | 6 +--- .../New/Messages/RowDescription.swift | 6 +--- Sources/PostgresNIO/New/PSQLRowStream.swift | 29 +++++-------------- Sources/PostgresNIO/New/PostgresCell.swift | 6 +--- Sources/PostgresNIO/New/PostgresQuery.swift | 12 ++------ .../PostgresRowSequence-multi-decode.swift | 2 -- .../PostgresNIO/New/PostgresRowSequence.swift | 2 -- Tests/IntegrationTests/AsyncTests.swift | 2 -- .../New/PostgresRowSequenceTests.swift | 2 -- 14 files changed, 22 insertions(+), 86 deletions(-) diff --git a/Package.swift b/Package.swift index 7e382068..ea9c1c6b 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.5 +// swift-tools-version:5.6 import PackageDescription let package = Package( diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index ac533c6e..d98d2f17 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -8,7 +8,9 @@ import Logging import NIOPosix /// A Postgres connection. Use it to run queries against a Postgres server. -public final class PostgresConnection { +/// +/// Thread safety is achieved by dispatching all access to shared state onto the underlying EventLoop. +public final class PostgresConnection: @unchecked Sendable { /// A Postgres connection ID public typealias ID = Int @@ -449,7 +451,6 @@ extension PostgresConnection { // MARK: Async/Await Interface -#if canImport(_Concurrency) extension PostgresConnection { /// Creates a new connection to a Postgres server. @@ -513,7 +514,6 @@ extension PostgresConnection { return try await promise.futureResult.map({ $0.asyncSequence() }).get() } } -#endif // MARK: EventLoopFuture interface @@ -785,7 +785,3 @@ extension PostgresConnection.InternalConfiguration { self.requireBackendKeyData = config.connection.requireBackendKeyData } } - -#if swift(>=5.6) -extension PostgresConnection: @unchecked Sendable {} -#endif diff --git a/Sources/PostgresNIO/Data/PostgresData.swift b/Sources/PostgresNIO/Data/PostgresData.swift index 1ae8af2f..0137ad87 100644 --- a/Sources/PostgresNIO/Data/PostgresData.swift +++ b/Sources/PostgresNIO/Data/PostgresData.swift @@ -1,7 +1,7 @@ import NIOCore import struct Foundation.UUID -public struct PostgresData: CustomStringConvertible, CustomDebugStringConvertible { +public struct PostgresData: Sendable, CustomStringConvertible, CustomDebugStringConvertible { public static var null: PostgresData { return .init(type: .null) } @@ -112,7 +112,3 @@ extension PostgresData: PostgresDataConvertible { return self } } - -#if swift(>=5.6) -extension PostgresData: Sendable {} -#endif diff --git a/Sources/PostgresNIO/Data/PostgresDataType.swift b/Sources/PostgresNIO/Data/PostgresDataType.swift index 55f529dc..50d2b0eb 100644 --- a/Sources/PostgresNIO/Data/PostgresDataType.swift +++ b/Sources/PostgresNIO/Data/PostgresDataType.swift @@ -3,7 +3,7 @@ /// Currently there a two wire formats supported: /// - text /// - binary -public enum PostgresFormat: Int16 { +public enum PostgresFormat: Int16, Sendable { case text = 0 case binary = 1 } @@ -17,11 +17,6 @@ extension PostgresFormat: CustomStringConvertible { } } -#if swift(>=5.6) -extension PostgresFormat: Sendable {} -#endif - - // TODO: The Codable conformance does not make any sense. Let's remove this with next major break. extension PostgresFormat: Codable {} @@ -31,7 +26,7 @@ public typealias PostgresFormatCode = PostgresFormat /// The data type's raw object ID. /// Use `select * from pg_type where oid = ;` to lookup more information. -public struct PostgresDataType: RawRepresentable, Hashable, CustomStringConvertible { +public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStringConvertible { /// `0` public static let null = PostgresDataType(0) /// `16` @@ -238,10 +233,6 @@ public struct PostgresDataType: RawRepresentable, Hashable, CustomStringConverti } } -#if swift(>=5.6) -extension PostgresDataType: Sendable {} -#endif - // TODO: The Codable conformance does not make any sense. Let's remove this with next major break. extension PostgresDataType: Codable {} diff --git a/Sources/PostgresNIO/Data/PostgresRow.swift b/Sources/PostgresNIO/Data/PostgresRow.swift index 914667e5..74d13590 100644 --- a/Sources/PostgresNIO/Data/PostgresRow.swift +++ b/Sources/PostgresNIO/Data/PostgresRow.swift @@ -7,7 +7,7 @@ import class Foundation.JSONDecoder /// - Warning: Please note that random access to cells in a ``PostgresRow`` have O(n) time complexity. If you require /// random access to cells in O(1) create a new ``PostgresRandomAccessRow`` with the given row and /// access it instead. -public struct PostgresRow { +public struct PostgresRow: Sendable { @usableFromInline let lookupTable: [String: Int] @usableFromInline @@ -138,7 +138,7 @@ public struct PostgresRandomAccessRow { } } -extension PostgresRandomAccessRow: RandomAccessCollection { +extension PostgresRandomAccessRow: Sendable, RandomAccessCollection { public typealias Element = PostgresCell public typealias Index = Int @@ -320,9 +320,3 @@ extension PostgresRow: CustomStringConvertible { return row.description } } - -#if swift(>=5.6) -extension PostgresRow: Sendable {} - -extension PostgresRandomAccessRow: Sendable {} -#endif diff --git a/Sources/PostgresNIO/New/Messages/DataRow.swift b/Sources/PostgresNIO/New/Messages/DataRow.swift index 4cdc92f8..b181e600 100644 --- a/Sources/PostgresNIO/New/Messages/DataRow.swift +++ b/Sources/PostgresNIO/New/Messages/DataRow.swift @@ -9,7 +9,7 @@ import NIOCore /// Not putting `DataRow` in ``PSQLBackendMessage`` is our way to trick /// the Swift compiler @usableFromInline -struct DataRow: PostgresBackendMessage.PayloadDecodable, Equatable { +struct DataRow: Sendable, PostgresBackendMessage.PayloadDecodable, Equatable { @usableFromInline var columnCount: Int16 @usableFromInline @@ -116,7 +116,3 @@ extension DataRow { return self[byteIndex] } } - -#if swift(>=5.5) -extension DataRow: Sendable {} -#endif diff --git a/Sources/PostgresNIO/New/Messages/RowDescription.swift b/Sources/PostgresNIO/New/Messages/RowDescription.swift index b6b0e614..730c2101 100644 --- a/Sources/PostgresNIO/New/Messages/RowDescription.swift +++ b/Sources/PostgresNIO/New/Messages/RowDescription.swift @@ -9,7 +9,7 @@ import NIOCore /// Not putting `DataRow` in ``PSQLBackendMessage`` is our way to trick /// the Swift compiler. @usableFromInline -struct RowDescription: PostgresBackendMessage.PayloadDecodable, Equatable { +struct RowDescription: PostgresBackendMessage.PayloadDecodable, Sendable, Equatable { /// Specifies the object ID of the parameter data type. @usableFromInline var columns: [Column] @@ -86,7 +86,3 @@ struct RowDescription: PostgresBackendMessage.PayloadDecodable, Equatable { return RowDescription(columns: result) } } - -#if swift(>=5.6) -extension RowDescription.Column: Sendable {} -#endif diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index c5a9cd3f..4c842275 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -1,7 +1,8 @@ import NIOCore import Logging -final class PSQLRowStream { +// Thread safety is guaranteed in the RowStream through dispatching onto the NIO EventLoop. +final class PSQLRowStream: @unchecked Sendable { private typealias AsyncSequenceSource = NIOThrowingAsyncSequenceProducer.Source enum RowSource { @@ -23,10 +24,7 @@ final class PSQLRowStream { case iteratingRows(onRow: (PostgresRow) throws -> (), EventLoopPromise, PSQLRowsDataSource) case waitingForAll([PostgresRow], EventLoopPromise<[PostgresRow]>, PSQLRowsDataSource) case consumed(Result) - - #if canImport(_Concurrency) case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource) - #endif } internal let rowDescription: [RowDescription.Column] @@ -64,8 +62,7 @@ final class PSQLRowStream { } // MARK: Async Sequence - - #if canImport(_Concurrency) + func asyncSequence() -> PostgresRowSequence { self.eventLoop.preconditionInEventLoop() @@ -150,7 +147,6 @@ final class PSQLRowStream { preconditionFailure("Invalid state: \(self.downstreamState)") } } - #endif // MARK: Consume in array @@ -312,12 +308,10 @@ final class PSQLRowStream { self.downstreamState = .waitingForAll(rows, promise, dataSource) // immediately request more dataSource.request(for: self) - - #if canImport(_Concurrency) + case .asyncSequence(let consumer, let source): let yieldResult = consumer.yield(contentsOf: newRows) self.executeActionBasedOnYieldResult(yieldResult, source: source) - #endif case .consumed(.success): preconditionFailure("How can we receive further rows, if we are supposed to be done") @@ -353,12 +347,10 @@ final class PSQLRowStream { case .waitingForAll(let rows, let promise, _): self.downstreamState = .consumed(.success(commandTag)) promise.succeed(rows) - - #if canImport(_Concurrency) + case .asyncSequence(let source, _): source.finish() self.downstreamState = .consumed(.success(commandTag)) - #endif case .consumed: break @@ -380,13 +372,11 @@ final class PSQLRowStream { case .waitingForAll(_, let promise, _): self.downstreamState = .consumed(.failure(error)) promise.fail(error) - - #if canImport(_Concurrency) + case .asyncSequence(let consumer, _): consumer.finish(error) self.downstreamState = .consumed(.failure(error)) - #endif - + case .consumed: break } @@ -432,8 +422,3 @@ protocol PSQLRowsDataSource { func cancel(for stream: PSQLRowStream) } - -#if swift(>=5.5) -// Thread safety is guaranteed in the RowStream through dispatching onto the NIO EventLoop. -extension PSQLRowStream: @unchecked Sendable {} -#endif diff --git a/Sources/PostgresNIO/New/PostgresCell.swift b/Sources/PostgresNIO/New/PostgresCell.swift index 39710e8e..d3cf8d4e 100644 --- a/Sources/PostgresNIO/New/PostgresCell.swift +++ b/Sources/PostgresNIO/New/PostgresCell.swift @@ -1,7 +1,7 @@ import NIOCore /// A representation of a cell value within a ``PostgresRow`` and ``PostgresRandomAccessRow``. -public struct PostgresCell: Equatable { +public struct PostgresCell: Sendable, Equatable { /// The cell's value as raw bytes. public var bytes: ByteBuffer? /// The cell's data type. This is important metadata when decoding the cell. @@ -86,7 +86,3 @@ extension PostgresCell { try self.decode(T.self, context: .default, file: file, line: line) } } - -#if swift(>=5.6) -extension PostgresCell: Sendable {} -#endif diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 6f224895..1ba75050 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -1,7 +1,7 @@ import NIOCore /// A Postgres SQL query, that can be executed on a Postgres server. Contains the raw sql string and bindings. -public struct PostgresQuery: Hashable { +public struct PostgresQuery: Sendable, Hashable { /// The query string public var sql: String /// The query binds @@ -104,9 +104,9 @@ struct PSQLExecuteStatement { var rowDescription: RowDescription? } -public struct PostgresBindings: Hashable { +public struct PostgresBindings: Sendable, Hashable { @usableFromInline - struct Metadata: Hashable { + struct Metadata: Sendable, Hashable { @usableFromInline var dataType: PostgresDataType @usableFromInline @@ -179,9 +179,3 @@ public struct PostgresBindings: Hashable { self.metadata.append(.init(dataType: postgresData.type, format: .binary)) } } - -#if swift(>=5.6) -extension PostgresQuery: Sendable {} -extension PostgresBindings: Sendable {} -extension PostgresBindings.Metadata: Sendable {} -#endif diff --git a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift index d7429ff8..ff212d0a 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift @@ -1,6 +1,5 @@ /// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrowsequence-multi-decode.sh -#if canImport(_Concurrency) extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient @@ -212,4 +211,3 @@ extension AsyncSequence where Element == PostgresRow { self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: .default, file: file, line: line) } } -#endif diff --git a/Sources/PostgresNIO/New/PostgresRowSequence.swift b/Sources/PostgresNIO/New/PostgresRowSequence.swift index 8248e14a..ccf4f69c 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence.swift @@ -1,7 +1,6 @@ import NIOCore import NIOConcurrencyHelpers -#if canImport(_Concurrency) /// An async sequence of ``PostgresRow``s. /// /// - Note: This is a struct to allow us to move to a move only type easily once they become available. @@ -112,4 +111,3 @@ struct AdaptiveRowBuffer: NIOAsyncSequenceProducerBackPressureStrategy { return bufferDepth < self.target } } -#endif diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index b1a72e5f..6857e461 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -7,7 +7,6 @@ import NIOTransportServices import NIOPosix import NIOCore -#if canImport(_Concurrency) final class AsyncPostgresConnectionTests: XCTestCase { func test1kRoundTrips() async throws { @@ -164,4 +163,3 @@ extension XCTestCase { } } } -#endif diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index 5cd69662..e1fdad11 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -6,7 +6,6 @@ import XCTest import NIOCore import Logging -#if canImport(_Concurrency) final class PostgresRowSequenceTests: XCTestCase { func testBackpressureWorks() async throws { @@ -467,4 +466,3 @@ final class MockRowDataSource: PSQLRowsDataSource { self._cancelCount.wrappingIncrement(ordering: .relaxed) } } -#endif From f21252b4fbf2d14ca180baed7420fe1f82ed0dfb Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 19 Apr 2023 16:21:10 +0200 Subject: [PATCH 136/292] Mark `RowDescription.Column` as Sendable (#338) --- Sources/PostgresNIO/New/Messages/RowDescription.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/PostgresNIO/New/Messages/RowDescription.swift b/Sources/PostgresNIO/New/Messages/RowDescription.swift index 730c2101..66c71215 100644 --- a/Sources/PostgresNIO/New/Messages/RowDescription.swift +++ b/Sources/PostgresNIO/New/Messages/RowDescription.swift @@ -15,7 +15,7 @@ struct RowDescription: PostgresBackendMessage.PayloadDecodable, Sendable, Equata var columns: [Column] @usableFromInline - struct Column: Equatable { + struct Column: Equatable, Sendable { /// The field name. @usableFromInline var name: String From 18a60efc950004d70dcc7b04650b3d2531b91210 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 20 Apr 2023 11:10:55 +0200 Subject: [PATCH 137/292] Make PSQLError public (#342) --- .../Connection/PostgresConnection.swift | 42 +- .../AuthenticationStateMachine.swift | 2 +- .../ConnectionStateMachine.swift | 12 +- Sources/PostgresNIO/New/PSQLError.swift | 431 +++++++++++++++--- .../New/PostgresBackendMessageDecoder.swift | 10 +- .../New/PostgresChannelHandler.swift | 6 +- Sources/PostgresNIO/Postgres+PSQLCompat.swift | 39 +- Tests/IntegrationTests/PostgresNIOTests.swift | 2 +- .../AuthenticationStateMachineTests.swift | 2 +- .../ConnectionStateMachineTests.swift | 2 +- .../PSQLFrontendMessageDecoder.swift | 10 +- .../New/Messages/BackendKeyDataTests.swift | 2 +- .../Messages/NotificationResponseTests.swift | 4 +- .../Messages/ParameterDescriptionTests.swift | 4 +- .../New/Messages/ParameterStatusTests.swift | 4 +- .../New/Messages/ReadyForQueryTests.swift | 4 +- .../New/Messages/RowDescriptionTests.swift | 8 +- .../New/PSQLBackendMessageTests.swift | 6 +- 18 files changed, 453 insertions(+), 137 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index d98d2f17..2061e6bc 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -283,7 +283,7 @@ public final class PostgresConnection: @unchecked Sendable { case is PSQLError: throw error default: - throw PSQLError.channel(underlying: error) + throw PSQLError.connectionError(underlying: error) } } } @@ -312,7 +312,7 @@ public final class PostgresConnection: @unchecked Sendable { var logger = logger logger[postgresMetadataKey: .connectionID] = "\(self.id)" guard query.binds.count <= Int(UInt16.max) else { - return self.channel.eventLoop.makeFailedFuture(PSQLError.tooManyParameters) + return self.channel.eventLoop.makeFailedFuture(PSQLError(code: .tooManyParameters, query: query)) } let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) @@ -344,7 +344,7 @@ public final class PostgresConnection: @unchecked Sendable { func execute(_ executeStatement: PSQLExecuteStatement, logger: Logger) -> EventLoopFuture { guard executeStatement.binds.count <= Int(UInt16.max) else { - return self.channel.eventLoop.makeFailedFuture(PSQLError.tooManyParameters) + return self.channel.eventLoop.makeFailedFuture(PSQLError(code: .tooManyParameters)) } let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let context = ExtendedQueryContext( @@ -493,14 +493,14 @@ extension PostgresConnection { public func query( _ query: PostgresQuery, logger: Logger, - file: String = #file, + file: String = #fileID, line: Int = #line ) async throws -> PostgresRowSequence { var logger = logger logger[postgresMetadataKey: .connectionID] = "\(self.id)" guard query.binds.count <= Int(UInt16.max) else { - throw PSQLError.tooManyParameters + throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line) } let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let context = ExtendedQueryContext( @@ -511,7 +511,14 @@ extension PostgresConnection { self.channel.write(PSQLTask.extendedQuery(context), promise: nil) - return try await promise.futureResult.map({ $0.asyncSequence() }).get() + do { + return try await promise.futureResult.map({ $0.asyncSequence() }).get() + } catch var error as PSQLError { + error.file = file + error.line = line + error.query = query + throw error // rethrow with more metadata + } } } @@ -530,7 +537,7 @@ extension PostgresConnection { public func query( _ query: PostgresQuery, logger: Logger, - file: String = #file, + file: String = #fileID, line: Int = #line ) -> EventLoopFuture { self.queryStream(query, logger: logger).flatMap { rowStream in @@ -540,7 +547,7 @@ extension PostgresConnection { } return PostgresQueryResult(metadata: metadata, rows: rows) } - } + }.enrichPSQLError(query: query, file: file, line: line) } /// Run a query on the Postgres server the connection is connected to and iterate the rows in a callback. @@ -557,7 +564,7 @@ extension PostgresConnection { public func query( _ query: PostgresQuery, logger: Logger, - file: String = #file, + file: String = #fileID, line: Int = #line, _ onRow: @escaping (PostgresRow) throws -> () ) -> EventLoopFuture { @@ -568,7 +575,7 @@ extension PostgresConnection { } return metadata } - } + }.enrichPSQLError(query: query, file: file, line: line) } } @@ -785,3 +792,18 @@ extension PostgresConnection.InternalConfiguration { self.requireBackendKeyData = config.connection.requireBackendKeyData } } + +extension EventLoopFuture { + func enrichPSQLError(query: PostgresQuery, file: String, line: Int) -> EventLoopFuture { + return self.flatMapErrorThrowing { error in + if var error = error as? PSQLError { + error.file = file + error.line = line + error.query = query + throw error + } else { + throw error + } + } + } +} diff --git a/Sources/PostgresNIO/New/Connection State Machine/AuthenticationStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/AuthenticationStateMachine.swift index 859a4d4b..245e5efd 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/AuthenticationStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/AuthenticationStateMachine.swift @@ -51,7 +51,7 @@ struct AuthenticationStateMachine { return .authenticated case .md5(let salt): guard self.authContext.password != nil else { - return self.setAndFireError(.authMechanismRequiresPassword) + return self.setAndFireError(PSQLError(code: .authMechanismRequiresPassword)) } self.state = .passwordAuthenticationSent return .sendPassword(.md5(salt: salt), self.authContext) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 91e6c007..eeab0a81 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -1076,15 +1076,15 @@ extension ConnectionStateMachine { extension ConnectionStateMachine { func shouldCloseConnection(reason error: PSQLError) -> Bool { - switch error.base { + switch error.code.base { case .sslUnsupported: return true case .failedToAddSSLHandler: return true case .queryCancelled: return false - case .server(let message): - guard let sqlState = message.fields[.sqlState] else { + case .server: + guard let sqlState = error.serverInfo?[.sqlState] else { // any error message that doesn't have a sql state field, is unexpected by default. return true } @@ -1095,7 +1095,7 @@ extension ConnectionStateMachine { } return false - case .decoding: + case .messageDecodingFailure: return true case .unexpectedBackendMessage: return true @@ -1115,8 +1115,6 @@ extension ConnectionStateMachine { preconditionFailure("Pure client error, that is thrown directly and should never ") case .connectionError: return true - case .casting(_): - preconditionFailure("Pure client error, that is thrown directly in PSQLRows") case .uncleanShutdown: return true } @@ -1142,7 +1140,7 @@ extension ConnectionStateMachine { self.state = .error(error) var action = ConnectionAction.CleanUpContext.Action.close - if case .uncleanShutdown = error.base { + if case .uncleanShutdown = error.code.base { action = .fireChannelInactive } diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index 2320c822..a2fa9b5b 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -1,92 +1,393 @@ import NIOCore -struct PSQLError: Error { +/// An error that is thrown from the PostgresClient. +public struct PSQLError: Error { - enum Base { - case sslUnsupported - case failedToAddSSLHandler(underlying: Error) - case server(PostgresBackendMessage.ErrorResponse) - case decoding(PSQLDecodingError) - case unexpectedBackendMessage(PostgresBackendMessage) - case unsupportedAuthMechanism(PSQLAuthScheme) - case authMechanismRequiresPassword - case saslError(underlyingError: Error) - case invalidCommandTag(String) + public struct Code: Sendable, Hashable, CustomStringConvertible { + enum Base: Sendable, Hashable { + case sslUnsupported + case failedToAddSSLHandler + case server + case messageDecodingFailure + case unexpectedBackendMessage + case unsupportedAuthMechanism + case authMechanismRequiresPassword + case saslError + case invalidCommandTag - case queryCancelled - case tooManyParameters - case connectionQuiescing - case connectionClosed - case connectionError(underlying: Error) - case uncleanShutdown + case queryCancelled + case tooManyParameters + case connectionQuiescing + case connectionClosed + case connectionError + case uncleanShutdown + } + + internal var base: Base + + private init(_ base: Base) { + self.base = base + } + + public static let sslUnsupported = Self.init(.sslUnsupported) + public static let failedToAddSSLHandler = Self(.failedToAddSSLHandler) + public static let server = Self(.server) + public static let messageDecodingFailure = Self(.messageDecodingFailure) + public static let unexpectedBackendMessage = Self(.unexpectedBackendMessage) + public static let unsupportedAuthMechanism = Self(.unsupportedAuthMechanism) + public static let authMechanismRequiresPassword = Self(.authMechanismRequiresPassword) + public static let saslError = Self.init(.saslError) + public static let invalidCommandTag = Self(.invalidCommandTag) + public static let queryCancelled = Self(.queryCancelled) + public static let tooManyParameters = Self(.tooManyParameters) + public static let connectionQuiescing = Self(.connectionQuiescing) + public static let connectionClosed = Self(.connectionClosed) + public static let connectionError = Self(.connectionError) + public static let uncleanShutdown = Self.init(.uncleanShutdown) - case casting(PostgresDecodingError) + public var description: String { + switch self.base { + case .sslUnsupported: + return "sslUnsupported" + case .failedToAddSSLHandler: + return "failedToAddSSLHandler" + case .server: + return "server" + case .messageDecodingFailure: + return "messageDecodingFailure" + case .unexpectedBackendMessage: + return "unexpectedBackendMessage" + case .unsupportedAuthMechanism: + return "unsupportedAuthMechanism" + case .authMechanismRequiresPassword: + return "authMechanismRequiresPassword" + case .saslError: + return "saslError" + case .invalidCommandTag: + return "invalidCommandTag" + case .queryCancelled: + return "queryCancelled" + case .tooManyParameters: + return "tooManyParameters" + case .connectionQuiescing: + return "connectionQuiescing" + case .connectionClosed: + return "connectionClosed" + case .connectionError: + return "connectionError" + case .uncleanShutdown: + return "uncleanShutdown" + } + } } - internal var base: Base + private var backing: Backing - private init(_ base: Base) { - self.base = base + private mutating func copyBackingStoriageIfNecessary() { + if !isKnownUniquelyReferenced(&self.backing) { + self.backing = self.backing.copy() + } } - static var sslUnsupported: PSQLError { - Self.init(.sslUnsupported) + /// The ``PSQLError/Code-swift.struct`` code + public internal(set) var code: Code { + get { self.backing.code } + set { + self.copyBackingStoriageIfNecessary() + self.backing.code = newValue + } } - static func failedToAddSSLHandler(underlying error: Error) -> PSQLError { - Self.init(.failedToAddSSLHandler(underlying: error)) + /// The info that was received from the server + public internal(set) var serverInfo: ServerInfo? { + get { self.backing.serverInfo } + set { + self.copyBackingStoriageIfNecessary() + self.backing.serverInfo = newValue + } } - static func server(_ message: PostgresBackendMessage.ErrorResponse) -> PSQLError { - Self.init(.server(message)) + /// The underlying error + public internal(set) var underlying: Error? { + get { self.backing.underlying } + set { + self.copyBackingStoriageIfNecessary() + self.backing.underlying = newValue + } } - static func decoding(_ error: PSQLDecodingError) -> PSQLError { - Self.init(.decoding(error)) + /// The file in which the Postgres operation was triggered that failed + public internal(set) var file: String? { + get { self.backing.file } + set { + self.copyBackingStoriageIfNecessary() + self.backing.file = newValue + } } - static func unexpectedBackendMessage(_ message: PostgresBackendMessage) -> PSQLError { - Self.init(.unexpectedBackendMessage(message)) + /// The line in which the Postgres operation was triggered that failed + public internal(set) var line: Int? { + get { self.backing.line } + set { + self.copyBackingStoriageIfNecessary() + self.backing.line = newValue + } } - static func unsupportedAuthMechanism(_ authScheme: PSQLAuthScheme) -> PSQLError { - Self.init(.unsupportedAuthMechanism(authScheme)) + /// The query that failed + public internal(set) var query: PostgresQuery? { + get { self.backing.query } + set { + self.copyBackingStoriageIfNecessary() + self.backing.query = newValue + } } - static var authMechanismRequiresPassword: PSQLError { - Self.init(.authMechanismRequiresPassword) + /// the backend message... we should keep this internal but we can use it to print more + /// advanced debug reasons. + var backendMessage: PostgresBackendMessage? { + get { self.backing.backendMessage } + set { + self.copyBackingStoriageIfNecessary() + self.backing.backendMessage = newValue + } } - static func sasl(underlying: Error) -> PSQLError { - Self.init(.saslError(underlyingError: underlying)) + /// the unsupported auth scheme... we should keep this internal but we can use it to print more + /// advanced debug reasons. + var unsupportedAuthScheme: UnsupportedAuthScheme? { + get { self.backing.unsupportedAuthScheme } + set { + self.copyBackingStoriageIfNecessary() + self.backing.unsupportedAuthScheme = newValue + } } - static func invalidCommandTag(_ value: String) -> PSQLError { - Self.init(.invalidCommandTag(value)) + /// the invalid command tag... we should keep this internal but we can use it to print more + /// advanced debug reasons. + var invalidCommandTag: String? { + get { self.backing.invalidCommandTag } + set { + self.copyBackingStoriageIfNecessary() + self.backing.invalidCommandTag = newValue + } + } + + init(code: Code, query: PostgresQuery, file: String? = nil, line: Int? = nil) { + self.backing = .init(code: code) + self.query = query + self.file = file + self.line = line + } + + init(code: Code) { + self.backing = .init(code: code) + } + + private final class Backing { + fileprivate var code: Code + + fileprivate var serverInfo: ServerInfo? + + fileprivate var underlying: Error? + + fileprivate var file: String? + + fileprivate var line: Int? + + fileprivate var query: PostgresQuery? + + fileprivate var backendMessage: PostgresBackendMessage? + + fileprivate var unsupportedAuthScheme: UnsupportedAuthScheme? + + fileprivate var invalidCommandTag: String? + + init(code: Code) { + self.code = code + } + + func copy() -> Self { + let new = Self.init(code: self.code) + new.serverInfo = self.serverInfo + new.underlying = self.underlying + new.file = self.file + new.line = self.line + new.query = self.query + new.backendMessage = self.backendMessage + return new + } + } + + public struct ServerInfo { + public struct Field: Hashable, Sendable { + fileprivate let backing: PostgresBackendMessage.Field + + private init(_ backing: PostgresBackendMessage.Field) { + self.backing = backing + } + + /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), + /// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a + /// localized translation of one of these. Always present. + public static let localizedSeverity = Self(.localizedSeverity) + + /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), + /// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message). + /// This is identical to the S field except that the contents are never localized. + /// This is present only in messages generated by PostgreSQL versions 9.6 and later. + public static let severity = Self(.severity) + + /// Code: the SQLSTATE code for the error (see Appendix A). Not localizable. Always present. + public static let sqlState = Self(.sqlState) + + /// Message: the primary human-readable error message. This should be accurate but terse (typically one line). + /// Always present. + public static let message = Self(.message) + + /// Detail: an optional secondary error message carrying more detail about the problem. + /// Might run to multiple lines. + public static let detail = Self(.detail) + + /// Hint: an optional suggestion what to do about the problem. + /// This is intended to differ from Detail in that it offers advice (potentially inappropriate) + /// rather than hard facts. Might run to multiple lines. + public static let hint = Self(.hint) + + /// Position: the field value is a decimal ASCII integer, indicating an error cursor + /// position as an index into the original query string. The first character has index 1, + /// and positions are measured in characters not bytes. + public static let position = Self(.position) + + /// Internal position: this is defined the same as the P field, but it is used when the + /// cursor position refers to an internally generated command rather than the one submitted by the client. + /// The q field will always appear when this field appears. + public static let internalPosition = Self(.internalPosition) + + /// Internal query: the text of a failed internally-generated command. + /// This could be, for example, a SQL query issued by a PL/pgSQL function. + public static let internalQuery = Self(.internalQuery) + + /// Where: an indication of the context in which the error occurred. + /// Presently this includes a call stack traceback of active procedural language functions and + /// internally-generated queries. The trace is one entry per line, most recent first. + public static let locationContext = Self(.locationContext) + + /// Schema name: if the error was associated with a specific database object, the name of + /// the schema containing that object, if any. + public static let schemaName = Self(.schemaName) + + /// Table name: if the error was associated with a specific table, the name of the table. + /// (Refer to the schema name field for the name of the table's schema.) + public static let tableName = Self(.tableName) + + /// Column name: if the error was associated with a specific table column, the name of the column. + /// (Refer to the schema and table name fields to identify the table.) + public static let columnName = Self(.columnName) + + /// Data type name: if the error was associated with a specific data type, the name of the data type. + /// (Refer to the schema name field for the name of the data type's schema.) + public static let dataTypeName = Self(.dataTypeName) + + /// Constraint name: if the error was associated with a specific constraint, the name of the constraint. + /// Refer to fields listed above for the associated table or domain. (For this purpose, indexes are + /// treated as constraints, even if they weren't created with constraint syntax.) + public static let constraintName = Self(.constraintName) + + /// File: the file name of the source-code location where the error was reported. + public static let file = Self(.file) + + /// Line: the line number of the source-code location where the error was reported. + public static let line = Self(.line) + + /// Routine: the name of the source-code routine reporting the error. + public static let routine = Self(.routine) + } + + let underlying: PostgresBackendMessage.ErrorResponse + + fileprivate init(_ underlying: PostgresBackendMessage.ErrorResponse) { + self.underlying = underlying + } + + /// The detailed server error information. This field is set if the ``PSQLError/code-swift.property`` is + /// ``PSQLError/Code-swift.struct/server``. + public subscript(field: Field) -> String? { + self.underlying.fields[field.backing] + } + } + + // MARK: - Internal convenience factory methods - + + static func unexpectedBackendMessage(_ message: PostgresBackendMessage) -> Self { + var new = Self(code: .unexpectedBackendMessage) + new.backendMessage = message + return new + } + + static func messageDecodingFailure(_ error: PostgresMessageDecodingError) -> Self { + var new = Self(code: .messageDecodingFailure) + new.underlying = error + return new + } + + static var connectionQuiescing: PSQLError { PSQLError(code: .connectionQuiescing) } + + static var connectionClosed: PSQLError { PSQLError(code: .connectionClosed) } + + static var authMechanismRequiresPassword: PSQLError { PSQLError(code: .authMechanismRequiresPassword) } + + static var sslUnsupported: PSQLError { PSQLError(code: .sslUnsupported) } + + static var queryCancelled: PSQLError { PSQLError(code: .queryCancelled) } + + static var uncleanShutdown: PSQLError { PSQLError(code: .uncleanShutdown) } + + static func server(_ response: PostgresBackendMessage.ErrorResponse) -> PSQLError { + var error = PSQLError(code: .server) + error.serverInfo = .init(response) + return error } - static var queryCancelled: PSQLError { - Self.init(.queryCancelled) + static func sasl(underlying: Error) -> PSQLError { + var error = PSQLError(code: .saslError) + error.underlying = underlying + return error } - static var tooManyParameters: PSQLError { - Self.init(.tooManyParameters) + static func failedToAddSSLHandler(underlying: Error) -> PSQLError { + var error = PSQLError(code: .failedToAddSSLHandler) + error.underlying = underlying + return error } - static var connectionQuiescing: PSQLError { - Self.init(.connectionQuiescing) + static func connectionError(underlying: Error) -> PSQLError { + var error = PSQLError(code: .connectionError) + error.underlying = underlying + return error } - static var connectionClosed: PSQLError { - Self.init(.connectionClosed) + static func unsupportedAuthMechanism(_ authScheme: UnsupportedAuthScheme) -> PSQLError { + var error = PSQLError(code: .unsupportedAuthMechanism) + error.unsupportedAuthScheme = authScheme + return error } - static func channel(underlying: Error) -> PSQLError { - Self.init(.connectionError(underlying: underlying)) + static func invalidCommandTag(_ value: String) -> PSQLError { + var error = PSQLError(code: .invalidCommandTag) + error.invalidCommandTag = value + return error } - static var uncleanShutdown: PSQLError { - Self.init(.uncleanShutdown) + enum UnsupportedAuthScheme { + case none + case kerberosV5 + case md5 + case plaintext + case scmCredential + case gss + case sspi + case sasl(mechanisms: [String]) } } @@ -110,25 +411,25 @@ public struct PostgresDecodingError: Error, Equatable { public static let failure = Self.init(.failure) } - /// The casting error code + /// The decoding error code public let code: Code - /// The cell's column name for which the casting failed + /// The cell's column name for which the decoding failed public let columnName: String - /// The cell's column index for which the casting failed + /// The cell's column index for which the decoding failed public let columnIndex: Int - /// The swift type the cell should have been casted into + /// The swift type the cell should have been decoded into public let targetType: Any.Type - /// The cell's postgres data type for which the casting failed + /// The cell's postgres data type for which the decoding failed public let postgresType: PostgresDataType - /// The cell's postgres format for which the casting failed + /// The cell's postgres format for which the decoding failed public let postgresFormat: PostgresFormat - /// A copy of the cell data which was attempted to be casted + /// A copy of the cell data which was attempted to be decoded public let postgresData: ByteBuffer? - /// The file the casting/decoding was attempted in + /// The file the decoding was attempted in public let file: String - /// The line the casting/decoding was attempted in + /// The line the decoding was attempted in public let line: Int @usableFromInline @@ -175,13 +476,3 @@ extension PostgresDecodingError: CustomStringConvertible { "Database error" } } -enum PSQLAuthScheme { - case none - case kerberosV5 - case md5 - case plaintext - case scmCredential - case gss - case sspi - case sasl(mechanisms: [String]) -} diff --git a/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift b/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift index 076daa19..4e3b630e 100644 --- a/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift +++ b/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift @@ -56,7 +56,7 @@ struct PostgresBackendMessageDecoder: NIOSingleStepByteToMessageDecoder { guard let messageID = PostgresBackendMessage.ID(rawValue: idByte) else { buffer.moveReaderIndex(to: startReaderIndex) let completeMessage = buffer.readSlice(length: Int(length) + 1)! - throw PSQLDecodingError.unknownMessageIDReceived(messageID: idByte, messageBytes: completeMessage) + throw PostgresMessageDecodingError.unknownMessageIDReceived(messageID: idByte, messageBytes: completeMessage) } // 3. decode the message @@ -69,7 +69,7 @@ struct PostgresBackendMessageDecoder: NIOSingleStepByteToMessageDecoder { } catch let error as PSQLPartialDecodingError { buffer.moveReaderIndex(to: startReaderIndex) let completeMessage = buffer.readSlice(length: Int(length) + 1)! - throw PSQLDecodingError.withPartialError(error, messageID: messageID.rawValue, messageBytes: completeMessage) + throw PostgresMessageDecodingError.withPartialError(error, messageID: messageID.rawValue, messageBytes: completeMessage) } catch { preconditionFailure("Expected to only see `PartialDecodingError`s here.") } @@ -87,7 +87,7 @@ struct PostgresBackendMessageDecoder: NIOSingleStepByteToMessageDecoder { /// /// If you encounter a `DecodingError` when using a trusted Postgres server please make to file an issue at: /// [https://github.com/vapor/postgres-nio/issues](https://github.com/vapor/postgres-nio/issues) -struct PSQLDecodingError: Error { +struct PostgresMessageDecodingError: Error { /// The backend message ID bytes let messageID: UInt8 @@ -112,7 +112,7 @@ struct PSQLDecodingError: Error { var byteBuffer = messageBytes let data = byteBuffer.readData(length: byteBuffer.readableBytes)! - return PSQLDecodingError( + return PostgresMessageDecodingError( messageID: messageID, payload: data.base64EncodedString(), description: partialError.description, @@ -129,7 +129,7 @@ struct PSQLDecodingError: Error { var byteBuffer = messageBytes let data = byteBuffer.readData(length: byteBuffer.readableBytes)! - return PSQLDecodingError( + return PostgresMessageDecodingError( messageID: messageID, payload: data.base64EncodedString(), description: "Received a message with messageID '\(Character(UnicodeScalar(messageID)))'. There is no message type associated with this message identifier.", diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 089dbf7e..ec02cd2c 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -91,7 +91,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { func errorCaught(context: ChannelHandlerContext, error: Error) { self.logger.debug("Channel error caught.", metadata: [.error: "\(error)"]) - let action = self.state.errorHappened(.channel(underlying: error)) + let action = self.state.errorHappened(.connectionError(underlying: error)) self.run(action, with: context) } @@ -146,8 +146,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.run(action, with: context) } - } catch let error as PSQLDecodingError { - let action = self.state.errorHappened(.decoding(error)) + } catch let error as PostgresMessageDecodingError { + let action = self.state.errorHappened(.messageDecodingFailure(error)) self.run(action, with: context) } catch { preconditionFailure("Expected to only get PSQLDecodingErrors from the PSQLBackendMessageDecoder.") diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index 674b4273..55870f8a 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -2,40 +2,45 @@ import NIOCore extension PSQLError { func toPostgresError() -> Error { - switch self.base { + switch self.code.base { case .queryCancelled: return self - case .server(let errorMessage): + case .server: + guard let serverInfo = self.serverInfo else { + return self + } + var fields = [PostgresMessage.Error.Field: String]() - fields.reserveCapacity(errorMessage.fields.count) - errorMessage.fields.forEach { (key, value) in + fields.reserveCapacity(serverInfo.underlying.fields.count) + serverInfo.underlying.fields.forEach { (key, value) in fields[PostgresMessage.Error.Field(rawValue: key.rawValue)!] = value } return PostgresError.server(PostgresMessage.Error(fields: fields)) case .sslUnsupported: return PostgresError.protocol("Server does not support TLS") - case .failedToAddSSLHandler(underlying: let underlying): - return underlying - case .decoding(let decodingError): - return PostgresError.protocol("Error decoding message: \(decodingError)") - case .unexpectedBackendMessage(let message): + case .failedToAddSSLHandler: + return self.underlying ?? self + case .messageDecodingFailure: + let message = self.underlying != nil ? String(describing: self.underlying!) : "no message" + return PostgresError.protocol("Error decoding message: \(message)") + case .unexpectedBackendMessage: + let message = self.backendMessage != nil ? String(describing: self.backendMessage!) : "no message" return PostgresError.protocol("Unexpected message: \(message)") - case .unsupportedAuthMechanism(let authScheme): - return PostgresError.protocol("Unsupported auth scheme: \(authScheme)") + case .unsupportedAuthMechanism: + let message = self.unsupportedAuthScheme != nil ? String(describing: self.unsupportedAuthScheme!) : "no scheme" + return PostgresError.protocol("Unsupported auth scheme: \(message)") case .authMechanismRequiresPassword: return PostgresError.protocol("Unable to authenticate without password") - case .saslError(underlyingError: let underlying): - return underlying + case .saslError: + return self.underlying ?? self case .tooManyParameters, .invalidCommandTag: return self case .connectionQuiescing: return PostgresError.connectionClosed case .connectionClosed: return PostgresError.connectionClosed - case .connectionError(underlying: let underlying): - return underlying - case .casting(let castingError): - return castingError + case .connectionError: + return self.underlying ?? self case .uncleanShutdown: return PostgresError.protocol("Unexpected connection close") } diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index 8c84e280..61800463 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -1093,7 +1093,7 @@ final class PostgresNIOTests: XCTestCase { defer { XCTAssertNoThrow( try conn?.close().wait() ) } let binds = [PostgresData].init(repeating: .null, count: Int(UInt16.max) + 1) XCTAssertThrowsError(try conn?.query("SELECT version()", binds).wait()) { error in - guard case .tooManyParameters = (error as? PSQLError)?.base else { + guard case .tooManyParameters = (error as? PSQLError)?.code.base else { return XCTFail("Unexpected error: \(error)") } } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift index 18fbc71b..87478e63 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift @@ -69,7 +69,7 @@ class AuthenticationStateMachineTests: XCTestCase { // MARK: Test unsupported messages func testUnsupportedAuthMechanism() { - let unsupported: [(PostgresBackendMessage.Authentication, PSQLAuthScheme)] = [ + let unsupported: [(PostgresBackendMessage.Authentication, PSQLError.UnsupportedAuthScheme)] = [ (.kerberosV5, .kerberosV5), (.scmCredential, .scmCredential), (.gss, .gss), diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index aeabc1fa..eaf427d5 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -132,7 +132,7 @@ class ConnectionStateMachineTests: XCTestCase { // test ignore unclean shutdown when closing connection var stateIgnoreChannelError = ConnectionStateMachine(.closing) - XCTAssertEqual(stateIgnoreChannelError.errorHappened(PSQLError.channel(underlying: NIOSSLError.uncleanShutdown)), .wait) + XCTAssertEqual(stateIgnoreChannelError.errorHappened(.connectionError(underlying: NIOSSLError.uncleanShutdown)), .wait) XCTAssertEqual(stateIgnoreChannelError.closed(), .fireChannelInactive) // test ignore any other error when closing connection diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index 91471d86..fc3f8858 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -77,7 +77,7 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { return .startup(startup) default: - throw PSQLDecodingError.unknownStartupCodeReceived(code: code, messageBytes: messageSlice) + throw PostgresMessageDecodingError.unknownStartupCodeReceived(code: code, messageBytes: messageSlice) } } @@ -97,7 +97,7 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { // 2. make sure we have a known message identifier guard let messageID = PostgresFrontendMessage.ID(rawValue: idByte) else { - throw PSQLDecodingError.unknownMessageIDReceived(messageID: idByte, messageBytes: completeMessageBuffer) + throw PostgresMessageDecodingError.unknownMessageIDReceived(messageID: idByte, messageBytes: completeMessageBuffer) } // 3. decode the message @@ -109,7 +109,7 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { return try PostgresFrontendMessage.decode(from: &slice, for: messageID) } catch let error as PSQLPartialDecodingError { - throw PSQLDecodingError.withPartialError(error, messageID: messageID.rawValue, messageBytes: completeMessageBuffer) + throw PostgresMessageDecodingError.withPartialError(error, messageID: messageID.rawValue, messageBytes: completeMessageBuffer) } catch { preconditionFailure("Expected to only see `PartialDecodingError`s here.") } @@ -153,7 +153,7 @@ extension PostgresFrontendMessage { } } -extension PSQLDecodingError { +extension PostgresMessageDecodingError { static func unknownStartupCodeReceived( code: UInt32, messageBytes: ByteBuffer, @@ -163,7 +163,7 @@ extension PSQLDecodingError { var byteBuffer = messageBytes let data = byteBuffer.readData(length: byteBuffer.readableBytes)! - return PSQLDecodingError( + return PostgresMessageDecodingError( messageID: 0, payload: data.base64EncodedString(), description: "Received a startup code '\(code)'. There is no message associated with this code.", diff --git a/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift b/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift index b67145c2..d41607e3 100644 --- a/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift @@ -33,7 +33,7 @@ class BackendKeyDataTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: expected, decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { - XCTAssert($0 is PSQLDecodingError) + XCTAssert($0 is PostgresMessageDecodingError) } } } diff --git a/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift index 7928e3f8..9a8a1529 100644 --- a/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/NotificationResponseTests.swift @@ -41,7 +41,7 @@ class NotificationResponseTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLDecodingError) + XCTAssert($0 is PostgresMessageDecodingError) } } @@ -56,7 +56,7 @@ class NotificationResponseTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLDecodingError) + XCTAssert($0 is PostgresMessageDecodingError) } } } diff --git a/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift b/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift index dd42aea4..a6bc32a1 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParameterDescriptionTests.swift @@ -44,7 +44,7 @@ class ParameterDescriptionTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLDecodingError) + XCTAssert($0 is PostgresMessageDecodingError) } } @@ -63,7 +63,7 @@ class ParameterDescriptionTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLDecodingError) + XCTAssert($0 is PostgresMessageDecodingError) } } } diff --git a/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift b/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift index ca4aa942..4513bbce 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParameterStatusTests.swift @@ -55,7 +55,7 @@ class ParameterStatusTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLDecodingError) + XCTAssert($0 is PostgresMessageDecodingError) } } @@ -69,7 +69,7 @@ class ParameterStatusTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLDecodingError) + XCTAssert($0 is PostgresMessageDecodingError) } } } diff --git a/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift b/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift index e915be72..62a8c62f 100644 --- a/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ReadyForQueryTests.swift @@ -48,7 +48,7 @@ class ReadyForQueryTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLDecodingError) + XCTAssert($0 is PostgresMessageDecodingError) } } @@ -62,7 +62,7 @@ class ReadyForQueryTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLDecodingError) + XCTAssert($0 is PostgresMessageDecodingError) } } diff --git a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift index 899c88f1..4eed785a 100644 --- a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift @@ -60,7 +60,7 @@ class RowDescriptionTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLDecodingError) + XCTAssert($0 is PostgresMessageDecodingError) } } @@ -82,7 +82,7 @@ class RowDescriptionTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLDecodingError) + XCTAssert($0 is PostgresMessageDecodingError) } } @@ -105,7 +105,7 @@ class RowDescriptionTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLDecodingError) + XCTAssert($0 is PostgresMessageDecodingError) } } @@ -128,7 +128,7 @@ class RowDescriptionTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) })) { - XCTAssert($0 is PSQLDecodingError) + XCTAssert($0 is PostgresMessageDecodingError) } } diff --git a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift index d55e86bc..10e8503a 100644 --- a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift @@ -196,7 +196,7 @@ class PSQLBackendMessageTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { - XCTAssert($0 is PSQLDecodingError) + XCTAssert($0 is PostgresMessageDecodingError) } } } @@ -238,7 +238,7 @@ class PSQLBackendMessageTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(failBuffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { - XCTAssert($0 is PSQLDecodingError) + XCTAssert($0 is PostgresMessageDecodingError) } } } @@ -251,7 +251,7 @@ class PSQLBackendMessageTests: XCTestCase { XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { - XCTAssert($0 is PSQLDecodingError) + XCTAssert($0 is PostgresMessageDecodingError) } } From 606c68a52a610b719dcdbc652ebd778335f57ab6 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 20 Apr 2023 12:10:10 +0200 Subject: [PATCH 138/292] Require new swift-nio versions and fix warnings (#343) --- Package.swift | 10 +++++----- .../PostgresNIO/New/Data/UUID+PostgresCodable.swift | 1 + Sources/PostgresNIO/New/PostgresCodable.swift | 3 ++- Tests/PostgresNIOTests/New/PostgresQueryTests.swift | 2 +- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/Package.swift b/Package.swift index ea9c1c6b..afd064a1 100644 --- a/Package.swift +++ b/Package.swift @@ -13,13 +13,13 @@ let package = Package( .library(name: "PostgresNIO", targets: ["PostgresNIO"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"), - .package(url: "https://github.com/apple/swift-nio.git", from: "2.44.0"), - .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.13.1"), - .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.22.1"), + .package(url: "https://github.com/apple/swift-atomics.git", from: "1.1.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.50.0"), + .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.16.0"), + .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.23.1"), .package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "3.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.0.0"), - .package(url: "https://github.com/apple/swift-log.git", from: "1.4.4"), + .package(url: "https://github.com/apple/swift-log.git", from: "1.5.2"), ], targets: [ .target(name: "PostgresNIO", dependencies: [ diff --git a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift index 632d5d93..e44d77e5 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift @@ -1,6 +1,7 @@ import NIOCore import struct Foundation.UUID import typealias Foundation.uuid_t +import NIOFoundationCompat extension UUID: PostgresNonThrowingEncodable { public static var psqlType: PostgresDataType { diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift index bd4e7f91..68291eac 100644 --- a/Sources/PostgresNIO/New/PostgresCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -1,5 +1,6 @@ import NIOCore -import Foundation +import class Foundation.JSONEncoder +import class Foundation.JSONDecoder /// A type that can encode itself to a postgres wire binary representation. public protocol PostgresEncodable { diff --git a/Tests/PostgresNIOTests/New/PostgresQueryTests.swift b/Tests/PostgresNIOTests/New/PostgresQueryTests.swift index 926541f0..f50d414a 100644 --- a/Tests/PostgresNIOTests/New/PostgresQueryTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresQueryTests.swift @@ -9,7 +9,7 @@ final class PostgresQueryTests: XCTestCase { let null: UUID? = nil let uuid: UUID? = UUID() - var query: PostgresQuery = """ + let query: PostgresQuery = """ INSERT INTO foo (id, title, something) SET (\(uuid), \(string), \(null)); """ From c996d6256d2a0406cc5a99efe5b5568b2a7f61bc Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 20 Apr 2023 13:04:39 +0200 Subject: [PATCH 139/292] Deprecating PostgresDataConvertible, PostgresMessageType (#313) --- .../PostgresNIO/Data/PostgresData+Array.swift | 12 ++--- .../PostgresNIO/Data/PostgresData+Bool.swift | 1 + .../PostgresNIO/Data/PostgresData+Bytes.swift | 1 + .../PostgresNIO/Data/PostgresData+Date.swift | 1 + .../Data/PostgresData+Decimal.swift | 1 + .../Data/PostgresData+Double.swift | 1 + .../PostgresNIO/Data/PostgresData+Float.swift | 1 + .../PostgresNIO/Data/PostgresData+Int.swift | 5 ++ .../PostgresNIO/Data/PostgresData+JSON.swift | 2 + .../PostgresNIO/Data/PostgresData+JSONB.swift | 2 + .../Data/PostgresData+Optional.swift | 1 + .../Data/PostgresData+RawRepresentable.swift | 1 + .../PostgresNIO/Data/PostgresData+Set.swift | 1 + .../Data/PostgresData+String.swift | 1 + .../PostgresNIO/Data/PostgresData+UUID.swift | 1 + Sources/PostgresNIO/Data/PostgresData.swift | 11 +++- .../Data/PostgresDataConvertible.swift | 1 + Sources/PostgresNIO/Data/PostgresRow.swift | 2 + .../PostgresMessage+SASLResponse.swift | 54 ++++++++++--------- Sources/PostgresNIO/Docs.docc/index.md | 4 ++ .../Message/PostgresMessage+0.swift | 3 ++ .../PostgresMessage+BackendKeyData.swift | 35 ++++++------ .../Message/PostgresMessage+DataRow.swift | 47 ++++++++-------- .../Message/PostgresMessage+Error.swift | 37 +++++++------ .../Message/PostgresMessage+Identifier.swift | 2 + ...PostgresMessage+NotificationResponse.swift | 37 +++++++------ .../PostgresMessage+RowDescription.swift | 35 ++++++------ .../Message/PostgresMessageType.swift | 6 ++- Tests/IntegrationTests/PerformanceTests.swift | 2 + Tests/IntegrationTests/PostgresNIOTests.swift | 15 +++++- .../Data/PostgresData+JSONTests.swift | 1 + 31 files changed, 201 insertions(+), 123 deletions(-) diff --git a/Sources/PostgresNIO/Data/PostgresData+Array.swift b/Sources/PostgresNIO/Data/PostgresData+Array.swift index d0c1c6f4..5d648db6 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Array.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Array.swift @@ -1,14 +1,14 @@ import NIOCore extension PostgresData { - public init(array: [T]) - where T: PostgresDataConvertible - { + @available(*, deprecated, message: "Use ``PostgresQuery`` and ``PostgresBindings`` instead.") + public init(array: [T]) where T: PostgresDataConvertible { self.init( array: array.map { $0.postgresData }, elementType: T.postgresDataType ) } + public init(array: [PostgresData?], elementType: PostgresDataType) { var buffer = ByteBufferAllocator().buffer(capacity: 0) // 0 if empty, 1 if not @@ -46,9 +46,8 @@ extension PostgresData { ) } - public func array(of type: T.Type = T.self) -> [T]? - where T: PostgresDataConvertible - { + @available(*, deprecated, message: "Use ``PostgresRow`` and ``PostgresDecodable`` instead.") + public func array(of type: T.Type = T.self) -> [T]? where T: PostgresDataConvertible { guard let array = self.array else { return nil } @@ -114,6 +113,7 @@ extension PostgresData { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Array: PostgresDataConvertible where Element: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { guard let arrayType = Element.postgresDataType.arrayType else { diff --git a/Sources/PostgresNIO/Data/PostgresData+Bool.swift b/Sources/PostgresNIO/Data/PostgresData+Bool.swift index 99e0c670..0b9f2738 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Bool.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Bool.swift @@ -47,6 +47,7 @@ extension PostgresData: ExpressibleByBooleanLiteral { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Bool: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { return .bool diff --git a/Sources/PostgresNIO/Data/PostgresData+Bytes.swift b/Sources/PostgresNIO/Data/PostgresData+Bytes.swift index 292c3c0a..5ec507cd 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Bytes.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Bytes.swift @@ -21,6 +21,7 @@ extension PostgresData { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Data: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { return .bytea diff --git a/Sources/PostgresNIO/Data/PostgresData+Date.swift b/Sources/PostgresNIO/Data/PostgresData+Date.swift index 86fa2f17..6d730f25 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Date.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Date.swift @@ -36,6 +36,7 @@ extension PostgresData { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Date: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { return .timestamptz diff --git a/Sources/PostgresNIO/Data/PostgresData+Decimal.swift b/Sources/PostgresNIO/Data/PostgresData+Decimal.swift index 0d2047b6..3af709e5 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Decimal.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Decimal.swift @@ -16,6 +16,7 @@ extension PostgresData { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Decimal: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { return .numeric diff --git a/Sources/PostgresNIO/Data/PostgresData+Double.swift b/Sources/PostgresNIO/Data/PostgresData+Double.swift index 986f8e23..2d7735ef 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Double.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Double.swift @@ -34,6 +34,7 @@ extension PostgresData { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Double: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { return .float8 diff --git a/Sources/PostgresNIO/Data/PostgresData+Float.swift b/Sources/PostgresNIO/Data/PostgresData+Float.swift index 9931ae9c..45430934 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Float.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Float.swift @@ -28,6 +28,7 @@ extension PostgresData { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Float: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { return .float4 diff --git a/Sources/PostgresNIO/Data/PostgresData+Int.swift b/Sources/PostgresNIO/Data/PostgresData+Int.swift index 4729021f..5a97b3fb 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Int.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Int.swift @@ -183,6 +183,7 @@ extension PostgresData { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Int: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { .int8 } @@ -198,6 +199,7 @@ extension Int: PostgresDataConvertible { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension UInt8: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { .char } @@ -213,6 +215,7 @@ extension UInt8: PostgresDataConvertible { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Int16: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { .int2 } @@ -228,6 +231,7 @@ extension Int16: PostgresDataConvertible { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Int32: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { .int4 } @@ -243,6 +247,7 @@ extension Int32: PostgresDataConvertible { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Int64: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { .int8 } diff --git a/Sources/PostgresNIO/Data/PostgresData+JSON.swift b/Sources/PostgresNIO/Data/PostgresData+JSON.swift index 519b721d..53a2d84c 100644 --- a/Sources/PostgresNIO/Data/PostgresData+JSON.swift +++ b/Sources/PostgresNIO/Data/PostgresData+JSON.swift @@ -37,8 +37,10 @@ extension PostgresData { } } +@available(*, deprecated, message: "This protocol is going to be replaced with ``PostgresEncodable`` and ``PostgresDecodable`` and conforming to ``Codable`` at the same time") public protocol PostgresJSONCodable: Codable, PostgresDataConvertible { } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension PostgresJSONCodable { public static var postgresDataType: PostgresDataType { return .json diff --git a/Sources/PostgresNIO/Data/PostgresData+JSONB.swift b/Sources/PostgresNIO/Data/PostgresData+JSONB.swift index 0b374ba6..0d5befa3 100644 --- a/Sources/PostgresNIO/Data/PostgresData+JSONB.swift +++ b/Sources/PostgresNIO/Data/PostgresData+JSONB.swift @@ -48,8 +48,10 @@ extension PostgresData { } } +@available(*, deprecated, message: "This protocol is going to be replaced with ``PostgresEncodable`` and ``PostgresDecodable`` and conforming to ``Codable`` at the same time") public protocol PostgresJSONBCodable: Codable, PostgresDataConvertible { } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension PostgresJSONBCodable { public static var postgresDataType: PostgresDataType { return .jsonb diff --git a/Sources/PostgresNIO/Data/PostgresData+Optional.swift b/Sources/PostgresNIO/Data/PostgresData+Optional.swift index 6b492054..9468478a 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Optional.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Optional.swift @@ -1,3 +1,4 @@ +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Optional: PostgresDataConvertible where Wrapped: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { return Wrapped.postgresDataType diff --git a/Sources/PostgresNIO/Data/PostgresData+RawRepresentable.swift b/Sources/PostgresNIO/Data/PostgresData+RawRepresentable.swift index 68e090ea..6cc8316a 100644 --- a/Sources/PostgresNIO/Data/PostgresData+RawRepresentable.swift +++ b/Sources/PostgresNIO/Data/PostgresData+RawRepresentable.swift @@ -1,3 +1,4 @@ +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension RawRepresentable where Self.RawValue: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { RawValue.postgresDataType diff --git a/Sources/PostgresNIO/Data/PostgresData+Set.swift b/Sources/PostgresNIO/Data/PostgresData+Set.swift index 1a7cd0c1..ade48db1 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Set.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Set.swift @@ -1,3 +1,4 @@ +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension Set: PostgresDataConvertible where Element: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { [Element].postgresDataType diff --git a/Sources/PostgresNIO/Data/PostgresData+String.swift b/Sources/PostgresNIO/Data/PostgresData+String.swift index 66a08337..f38e2ab8 100644 --- a/Sources/PostgresNIO/Data/PostgresData+String.swift +++ b/Sources/PostgresNIO/Data/PostgresData+String.swift @@ -94,6 +94,7 @@ extension PostgresData: ExpressibleByStringLiteral { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension String: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { return .text diff --git a/Sources/PostgresNIO/Data/PostgresData+UUID.swift b/Sources/PostgresNIO/Data/PostgresData+UUID.swift index f899b345..7c2da080 100644 --- a/Sources/PostgresNIO/Data/PostgresData+UUID.swift +++ b/Sources/PostgresNIO/Data/PostgresData+UUID.swift @@ -29,6 +29,7 @@ extension PostgresData { } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension UUID: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { return .uuid diff --git a/Sources/PostgresNIO/Data/PostgresData.swift b/Sources/PostgresNIO/Data/PostgresData.swift index 0137ad87..d0be48eb 100644 --- a/Sources/PostgresNIO/Data/PostgresData.swift +++ b/Sources/PostgresNIO/Data/PostgresData.swift @@ -1,7 +1,7 @@ import NIOCore import struct Foundation.UUID -public struct PostgresData: Sendable, CustomStringConvertible, CustomDebugStringConvertible { +public struct PostgresData: Sendable { public static var null: PostgresData { return .init(type: .null) } @@ -26,7 +26,10 @@ public struct PostgresData: Sendable, CustomStringConvertible, CustomDebugString self.formatCode = formatCode self.value = value } - +} + +@available(*, deprecated, message: "Deprecating conformance to `CustomStringConvertible` as a first step of deprecating `PostgresData`. Please use `PostgresBindings` or `PostgresCell` instead.") +extension PostgresData: CustomStringConvertible { public var description: String { guard var value = self.value else { return "" @@ -93,12 +96,16 @@ public struct PostgresData: Sendable, CustomStringConvertible, CustomDebugString return "\(raw) (\(self.type))" } } +} +@available(*, deprecated, message: "Deprecating conformance to `CustomDebugStringConvertible` as a first step of deprecating `PostgresData`. Please use `PostgresBindings` or `PostgresCell` instead.") +extension PostgresData: CustomDebugStringConvertible { public var debugDescription: String { return self.description } } +@available(*, deprecated, message: "Deprecating conformance to `PostgresDataConvertible`, since it is deprecated.") extension PostgresData: PostgresDataConvertible { public static var postgresDataType: PostgresDataType { fatalError("PostgresData cannot be statically represented as a single data type") diff --git a/Sources/PostgresNIO/Data/PostgresDataConvertible.swift b/Sources/PostgresNIO/Data/PostgresDataConvertible.swift index 32e7fc41..675ed6fe 100644 --- a/Sources/PostgresNIO/Data/PostgresDataConvertible.swift +++ b/Sources/PostgresNIO/Data/PostgresDataConvertible.swift @@ -1,5 +1,6 @@ import Foundation +@available(*, deprecated, message: "This protocol is going to be replaced with ``PostgresEncodable`` and ``PostgresDecodable``") public protocol PostgresDataConvertible { static var postgresDataType: PostgresDataType { get } init?(postgresData: PostgresData) diff --git a/Sources/PostgresNIO/Data/PostgresRow.swift b/Sources/PostgresNIO/Data/PostgresRow.swift index 74d13590..af7758f4 100644 --- a/Sources/PostgresNIO/Data/PostgresRow.swift +++ b/Sources/PostgresNIO/Data/PostgresRow.swift @@ -265,6 +265,7 @@ extension PostgresRandomAccessRow { // MARK: Deprecated API extension PostgresRow { + @available(*, deprecated, message: "Will be removed from public API.") public var rowDescription: PostgresMessage.RowDescription { let fields = self.columns.map { column in PostgresMessage.RowDescription.Field( @@ -280,6 +281,7 @@ extension PostgresRow { return PostgresMessage.RowDescription(fields: fields) } + @available(*, deprecated, message: "Iterate the cells on `PostgresRow` instead.") public var dataRow: PostgresMessage.DataRow { let columns = self.data.map { PostgresMessage.DataRow.Column(value: $0) diff --git a/Sources/PostgresNIO/Deprecated/PostgresMessage+SASLResponse.swift b/Sources/PostgresNIO/Deprecated/PostgresMessage+SASLResponse.swift index dba414ce..dc3b1772 100644 --- a/Sources/PostgresNIO/Deprecated/PostgresMessage+SASLResponse.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresMessage+SASLResponse.swift @@ -30,35 +30,13 @@ extension PostgresMessage { extension PostgresMessage { /// SASL initial challenge response message sent by the client. - public struct SASLInitialResponse: PostgresMessageType { - public static var identifier: PostgresMessage.Identifier { - return .saslInitialResponse - } - + @available(*, deprecated, message: "Will be removed from public API") + public struct SASLInitialResponse { public let mechanism: String public let initialData: [UInt8] - public static func parse(from buffer: inout ByteBuffer) throws -> PostgresMessage.SASLInitialResponse { - guard let mechanism = buffer.readNullTerminatedString() else { - throw PostgresError.protocol("Could not parse SASL mechanism from initial response message") - } - guard let dataLength = buffer.readInteger(as: Int32.self) else { - throw PostgresError.protocol("Could not parse SASL initial data length from initial response message") - } - - var actualData: [UInt8] = [] - - if dataLength != -1 { - guard let data = buffer.readBytes(length: Int(dataLength)) else { - throw PostgresError.protocol("Could not parse SASL initial data from initial response message") - } - actualData = data - } - return SASLInitialResponse(mechanism: mechanism, initialData: actualData) - } - public func serialize(into buffer: inout ByteBuffer) throws { - buffer.writeNullTerminatedString(mechanism) + buffer.writeNullTerminatedString(self.mechanism) if initialData.count > 0 { buffer.writeInteger(Int32(initialData.count), as: Int32.self) // write(array:) writes Int16, which is incorrect here buffer.writeBytes(initialData) @@ -72,3 +50,29 @@ extension PostgresMessage { } } } + +@available(*, deprecated, message: "Deprecating conformance to `PostgresMessageType` since it is deprecated.") +extension PostgresMessage.SASLInitialResponse: PostgresMessageType { + public static var identifier: PostgresMessage.Identifier { + return .saslInitialResponse + } + + public static func parse(from buffer: inout ByteBuffer) throws -> Self { + guard let mechanism = buffer.readNullTerminatedString() else { + throw PostgresError.protocol("Could not parse SASL mechanism from initial response message") + } + guard let dataLength = buffer.readInteger(as: Int32.self) else { + throw PostgresError.protocol("Could not parse SASL initial data length from initial response message") + } + + var actualData: [UInt8] = [] + + if dataLength != -1 { + guard let data = buffer.readBytes(length: Int(dataLength)) else { + throw PostgresError.protocol("Could not parse SASL initial data from initial response message") + } + actualData = data + } + return .init(mechanism: mechanism, initialData: actualData) + } +} diff --git a/Sources/PostgresNIO/Docs.docc/index.md b/Sources/PostgresNIO/Docs.docc/index.md index 6b7fd5b0..e7363054 100644 --- a/Sources/PostgresNIO/Docs.docc/index.md +++ b/Sources/PostgresNIO/Docs.docc/index.md @@ -45,6 +45,7 @@ Features: - ``PostgresJSONEncoder`` - ``PostgresJSONDecoder`` - ``PostgresDataType`` +- ``PostgresFormat`` - ``PostgresNumeric`` ### Notifications @@ -72,8 +73,11 @@ removed from the public API with the next major release. - ``PostgresRequest`` - ``PostgresMessage`` - ``PostgresMessageType`` +- ``PostgresFormatCode`` - ``SASLAuthenticationManager`` - ``SASLAuthenticationMechanism`` +- ``SASLAuthenticationError`` +- ``SASLAuthenticationStepResult`` [SwiftNIO]: https://github.com/apple/swift-nio [SwiftLog]: https://github.com/apple/swift-log diff --git a/Sources/PostgresNIO/Message/PostgresMessage+0.swift b/Sources/PostgresNIO/Message/PostgresMessage+0.swift index f33e89a3..386ffd34 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+0.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+0.swift @@ -2,9 +2,11 @@ import NIOCore /// A frontend or backend Postgres message. public struct PostgresMessage: Equatable { + @available(*, deprecated, message: "Will be removed from public API.") public var identifier: Identifier public var data: ByteBuffer + @available(*, deprecated, message: "Will be removed from public API.") public init(identifier: Identifier, bytes: Data) where Data: Sequence, Data.Element == UInt8 { @@ -13,6 +15,7 @@ public struct PostgresMessage: Equatable { self.init(identifier: identifier, data: buffer) } + @available(*, deprecated, message: "Will be removed from public API.") public init(identifier: Identifier, data: ByteBuffer) { self.identifier = identifier self.data = data diff --git a/Sources/PostgresNIO/Message/PostgresMessage+BackendKeyData.swift b/Sources/PostgresNIO/Message/PostgresMessage+BackendKeyData.swift index 85c2277a..63a6af7d 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+BackendKeyData.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+BackendKeyData.swift @@ -3,22 +3,7 @@ import NIOCore extension PostgresMessage { /// Identifies the message as cancellation key data. /// The frontend must save these values if it wishes to be able to issue CancelRequest messages later. - public struct BackendKeyData: PostgresMessageType { - public static var identifier: PostgresMessage.Identifier { - .backendKeyData - } - - /// Parses an instance of this message type from a byte buffer. - public static func parse(from buffer: inout ByteBuffer) throws -> BackendKeyData { - guard let processID = buffer.readInteger(as: Int32.self) else { - throw PostgresError.protocol("Could not parse process id from backend key data") - } - guard let secretKey = buffer.readInteger(as: Int32.self) else { - throw PostgresError.protocol("Could not parse secret key from backend key data") - } - return .init(processID: processID, secretKey: secretKey) - } - + public struct BackendKeyData { /// The process ID of this backend. public var processID: Int32 @@ -26,3 +11,21 @@ extension PostgresMessage { public var secretKey: Int32 } } + +@available(*, deprecated, message: "Deprecating conformance to `PostgresMessageType` since it is deprecated.") +extension PostgresMessage.BackendKeyData: PostgresMessageType { + public static var identifier: PostgresMessage.Identifier { + .backendKeyData + } + + /// Parses an instance of this message type from a byte buffer. + public static func parse(from buffer: inout ByteBuffer) throws -> Self { + guard let processID = buffer.readInteger(as: Int32.self) else { + throw PostgresError.protocol("Could not parse process id from backend key data") + } + guard let secretKey = buffer.readInteger(as: Int32.self) else { + throw PostgresError.protocol("Could not parse secret key from backend key data") + } + return .init(processID: processID, secretKey: secretKey) + } +} diff --git a/Sources/PostgresNIO/Message/PostgresMessage+DataRow.swift b/Sources/PostgresNIO/Message/PostgresMessage+DataRow.swift index e5cc3d9d..655bfb1e 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+DataRow.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+DataRow.swift @@ -2,11 +2,7 @@ import NIOCore extension PostgresMessage { /// Identifies the message as a data row. - public struct DataRow: PostgresMessageType { - public static var identifier: PostgresMessage.Identifier { - return .dataRow - } - + public struct DataRow { public struct Column: CustomStringConvertible { /// The length of the column value, in bytes (this count does not include itself). /// Can be zero. As a special case, -1 indicates a NULL column value. No value bytes follow in the NULL case. @@ -23,23 +19,7 @@ extension PostgresMessage { } } } - - /// Parses an instance of this message type from a byte buffer. - public static func parse(from buffer: inout ByteBuffer) throws -> DataRow { - guard let columns = buffer.read(array: Column.self, { buffer in - if var slice = buffer.readNullableBytes() { - var copy = ByteBufferAllocator().buffer(capacity: slice.readableBytes) - copy.writeBuffer(&slice) - return .init(value: copy) - } else { - return .init(value: nil) - } - }) else { - throw PostgresError.protocol("Could not parse data row columns") - } - return .init(columns: columns) - } - + /// The data row's columns public var columns: [Column] @@ -49,3 +29,26 @@ extension PostgresMessage { } } } + +@available(*, deprecated, message: "Deprecating conformance to `PostgresMessageType` since it is deprecated.") +extension PostgresMessage.DataRow: PostgresMessageType { + public static var identifier: PostgresMessage.Identifier { + return .dataRow + } + + /// Parses an instance of this message type from a byte buffer. + public static func parse(from buffer: inout ByteBuffer) throws -> Self { + guard let columns = buffer.read(array: Column.self, { buffer in + if var slice = buffer.readNullableBytes() { + var copy = ByteBufferAllocator().buffer(capacity: slice.readableBytes) + copy.writeBuffer(&slice) + return .init(value: copy) + } else { + return .init(value: nil) + } + }) else { + throw PostgresError.protocol("Could not parse data row columns") + } + return .init(columns: columns) + } +} diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Error.swift b/Sources/PostgresNIO/Message/PostgresMessage+Error.swift index 51b9be7e..44f9e6bf 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Error.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Error.swift @@ -2,23 +2,7 @@ import NIOCore extension PostgresMessage { /// First message sent from the frontend during startup. - public struct Error: PostgresMessageType, CustomStringConvertible { - public static var identifier: PostgresMessage.Identifier { - return .error - } - - /// Parses an instance of this message type from a byte buffer. - public static func parse(from buffer: inout ByteBuffer) throws -> Error { - var fields: [Field: String] = [:] - while let field = buffer.readInteger(as: Field.self) { - guard let string = buffer.readNullTerminatedString() else { - throw PostgresError.protocol("Could not read error response string.") - } - fields[field] = string - } - return .init(fields: fields) - } - + public struct Error: CustomStringConvertible { public enum Field: UInt8, Hashable { /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), /// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a @@ -108,3 +92,22 @@ extension PostgresMessage { } } } + +@available(*, deprecated, message: "Deprecating conformance to `PostgresMessageType` since it is deprecated.") +extension PostgresMessage.Error: PostgresMessageType { + public static var identifier: PostgresMessage.Identifier { + return .error + } + + /// Parses an instance of this message type from a byte buffer. + public static func parse(from buffer: inout ByteBuffer) throws -> Self { + var fields: [Field: String] = [:] + while let field = buffer.readInteger(as: Field.self) { + guard let string = buffer.readNullTerminatedString() else { + throw PostgresError.protocol("Could not read error response string.") + } + fields[field] = string + } + return .init(fields: fields) + } +} diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift b/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift index 3c0c3ef0..786b91ef 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift @@ -3,6 +3,7 @@ import NIOCore extension PostgresMessage { /// Identifies an incoming or outgoing postgres message. Sent as the first byte, before the message size. /// Values are not unique across all identifiers, meaning some messages will require keeping state to identify. + @available(*, deprecated, message: "Will be removed from public API.") public struct Identifier: ExpressibleByIntegerLiteral, Equatable, CustomStringConvertible { // special public static let none: Identifier = 0x00 @@ -143,6 +144,7 @@ extension PostgresMessage { } extension ByteBuffer { + @available(*, deprecated, message: "Will be removed from public API") mutating func write(identifier: PostgresMessage.Identifier) { self.writeInteger(identifier.value) } diff --git a/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift b/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift index 4979e354..1a3b596d 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+NotificationResponse.swift @@ -2,25 +2,28 @@ import NIOCore extension PostgresMessage { /// Identifies the message as a notification response. - public struct NotificationResponse: PostgresMessageType { - public static let identifier = Identifier.notificationResponse - - /// Parses an instance of this message type from a byte buffer. - public static func parse(from buffer: inout ByteBuffer) throws -> Self { - guard let backendPID: Int32 = buffer.readInteger() else { - throw PostgresError.protocol("Invalid NotificationResponse message: unable to read backend PID") - } - guard let channel = buffer.readNullTerminatedString() else { - throw PostgresError.protocol("Invalid NotificationResponse message: unable to read channel") - } - guard let payload = buffer.readNullTerminatedString() else { - throw PostgresError.protocol("Invalid NotificationResponse message: unable to read payload") - } - return .init(backendPID: backendPID, channel: channel, payload: payload) - } - + public struct NotificationResponse { public var backendPID: Int32 public var channel: String public var payload: String } } + +@available(*, deprecated, message: "Deprecating conformance to `PostgresMessageType` since it is deprecated.") +extension PostgresMessage.NotificationResponse: PostgresMessageType { + public static let identifier = PostgresMessage.Identifier.notificationResponse + + /// Parses an instance of this message type from a byte buffer. + public static func parse(from buffer: inout ByteBuffer) throws -> Self { + guard let backendPID: Int32 = buffer.readInteger() else { + throw PostgresError.protocol("Invalid NotificationResponse message: unable to read backend PID") + } + guard let channel = buffer.readNullTerminatedString() else { + throw PostgresError.protocol("Invalid NotificationResponse message: unable to read channel") + } + guard let payload = buffer.readNullTerminatedString() else { + throw PostgresError.protocol("Invalid NotificationResponse message: unable to read payload") + } + return .init(backendPID: backendPID, channel: channel, payload: payload) + } +} diff --git a/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift b/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift index ee8fa919..5713cc99 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift @@ -2,12 +2,7 @@ import NIOCore extension PostgresMessage { /// Identifies the message as a row description. - public struct RowDescription: PostgresMessageType { - /// See `PostgresMessageType`. - public static var identifier: PostgresMessage.Identifier { - return .rowDescription - } - + public struct RowDescription { /// Describes a single field returns in a `RowDescription` message. public struct Field: CustomStringConvertible { static func parse(from buffer: inout ByteBuffer) throws -> Field { @@ -73,15 +68,7 @@ extension PostgresMessage { } } - /// Parses an instance of this message type from a byte buffer. - public static func parse(from buffer: inout ByteBuffer) throws -> RowDescription { - guard let fields = try buffer.read(array: Field.self, { buffer in - return try.parse(from: &buffer) - }) else { - throw PostgresError.protocol("Could not read row description fields") - } - return .init(fields: fields) - } + /// The fields supplied in the row description. public var fields: [Field] @@ -92,3 +79,21 @@ extension PostgresMessage { } } } + +@available(*, deprecated, message: "Deprecating conformance to `PostgresMessageType` since it is deprecated.") +extension PostgresMessage.RowDescription: PostgresMessageType { + /// See `PostgresMessageType`. + public static var identifier: PostgresMessage.Identifier { + return .rowDescription + } + + /// Parses an instance of this message type from a byte buffer. + public static func parse(from buffer: inout ByteBuffer) throws -> Self { + guard let fields = try buffer.read(array: Field.self, { buffer in + return try.parse(from: &buffer) + }) else { + throw PostgresError.protocol("Could not read row description fields") + } + return .init(fields: fields) + } +} diff --git a/Sources/PostgresNIO/Message/PostgresMessageType.swift b/Sources/PostgresNIO/Message/PostgresMessageType.swift index 604da4b9..170c4aec 100644 --- a/Sources/PostgresNIO/Message/PostgresMessageType.swift +++ b/Sources/PostgresNIO/Message/PostgresMessageType.swift @@ -1,12 +1,15 @@ import NIOCore +@available(*, deprecated, message: "Will be removed from public API. Internally we now use `PostgresBackendMessage` and `PostgresFrontendMessage`") public protocol PostgresMessageType { static var identifier: PostgresMessage.Identifier { get } static func parse(from buffer: inout ByteBuffer) throws -> Self func serialize(into buffer: inout ByteBuffer) throws } +@available(*, deprecated, message: "`PostgresMessageType` protocol is deprecated.") extension PostgresMessageType { + @available(*, deprecated, message: "Will be removed from public API.") func message() throws -> PostgresMessage { var buffer = ByteBufferAllocator().buffer(capacity: 0) try self.serialize(into: &buffer) @@ -17,7 +20,8 @@ extension PostgresMessageType { var message = message self = try Self.parse(from: &message.data) } - + + @available(*, deprecated, message: "Will be removed from public API.") public static var identifier: PostgresMessage.Identifier { return .none } diff --git a/Tests/IntegrationTests/PerformanceTests.swift b/Tests/IntegrationTests/PerformanceTests.swift index 59a2392a..5d30db5e 100644 --- a/Tests/IntegrationTests/PerformanceTests.swift +++ b/Tests/IntegrationTests/PerformanceTests.swift @@ -73,6 +73,7 @@ final class PerformanceTests: XCTestCase { } } + @available(*, deprecated, message: "Testing deprecated functionality") func testPerformanceSelectMediumModel() throws { let conn = try PostgresConnection.test(on: eventLoop).wait() defer { try! conn.close().wait() } @@ -115,6 +116,7 @@ final class PerformanceTests: XCTestCase { } } + @available(*, deprecated, message: "Testing deprecated functionality") func testPerformanceSelectLargeModel() throws { let conn = try PostgresConnection.test(on: eventLoop).wait() defer { try! conn.close().wait() } diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index 61800463..114ae2bc 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -531,6 +531,7 @@ final class PostgresNIOTests: XCTestCase { XCTAssertEqual(row?[data: "e"].string, "12345678.90") } + @available(*, deprecated, message: "Testing deprecated functionality") func testIntegerArrayParse() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) @@ -544,6 +545,7 @@ final class PostgresNIOTests: XCTestCase { XCTAssertEqual(row?[data: "array"].array(of: Int.self), [1, 2, 3]) } + @available(*, deprecated, message: "Testing deprecated functionality") func testEmptyIntegerArrayParse() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) @@ -556,7 +558,8 @@ final class PostgresNIOTests: XCTestCase { let row = rows?.first?.makeRandomAccess() XCTAssertEqual(row?[data: "array"].array(of: Int.self), []) } - + + @available(*, deprecated, message: "Testing deprecated functionality") func testOptionalIntegerArrayParse() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) @@ -570,6 +573,7 @@ final class PostgresNIOTests: XCTestCase { XCTAssertEqual(row?[data: "array"].array(of: Int?.self), [1, 2, nil, 4]) } + @available(*, deprecated, message: "Testing deprecated functionality") func testNullIntegerArrayParse() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) @@ -583,6 +587,7 @@ final class PostgresNIOTests: XCTestCase { XCTAssertEqual(row?[data: "array"].array(of: Int.self), nil) } + @available(*, deprecated, message: "Testing deprecated functionality") func testIntegerArraySerialize() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) @@ -598,6 +603,7 @@ final class PostgresNIOTests: XCTestCase { XCTAssertEqual(row?[data: "array"].array(of: Int.self), [1, 2, 3]) } + @available(*, deprecated, message: "Testing deprecated functionality") func testEmptyIntegerArraySerialize() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) @@ -612,7 +618,8 @@ final class PostgresNIOTests: XCTestCase { let row = rows?.first?.makeRandomAccess() XCTAssertEqual(row?[data: "array"].array(of: Int.self), []) } - + + @available(*, deprecated, message: "Testing deprecated functionality") func testOptionalIntegerArraySerialize() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) @@ -855,6 +862,7 @@ final class PostgresNIOTests: XCTestCase { XCTAssertEqual(row?[data: "t2_dateValue"].date, dateInTable2) } + @available(*, deprecated, message: "Testing deprecated functionality") func testStringArrays() { let query = """ SELECT @@ -936,6 +944,7 @@ final class PostgresNIOTests: XCTestCase { XCTAssertEqual(row?[data: "char"].string, "*") } + @available(*, deprecated, message: "Testing deprecated functionality") func testDoubleArraySerialization() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) @@ -1057,6 +1066,7 @@ final class PostgresNIOTests: XCTestCase { XCTAssertEqual(row?[data: "foo"].string, "qux") } + @available(*, deprecated, message: "Testing deprecated functionality") func testNullBind() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) @@ -1106,6 +1116,7 @@ final class PostgresNIOTests: XCTestCase { } // https://github.com/vapor/postgres-nio/issues/113 + @available(*, deprecated, message: "Testing deprecated functionality") func testVaryingCharArray() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) diff --git a/Tests/PostgresNIOTests/Data/PostgresData+JSONTests.swift b/Tests/PostgresNIOTests/Data/PostgresData+JSONTests.swift index a8287966..47dd89a1 100644 --- a/Tests/PostgresNIOTests/Data/PostgresData+JSONTests.swift +++ b/Tests/PostgresNIOTests/Data/PostgresData+JSONTests.swift @@ -2,6 +2,7 @@ import PostgresNIO import XCTest class PostgresData_JSONTests: XCTestCase { + @available(*, deprecated, message: "Testing deprecated functionality") func testJSONBConvertible() { struct Object: PostgresJSONBCodable { let foo: Int From 98b8e1b1488c706f8bff9eb07560745630a64679 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Thu, 20 Apr 2023 06:53:17 -0500 Subject: [PATCH 140/292] Add support for UDS and existing Channels (#335) --- .github/workflows/test.yml | 20 +- .../PostgresConnection+Configuration.swift | 276 ++++++++++++++++++ .../Connection/PostgresConnection.swift | 208 ++----------- ...sConnection+Configuration+Deprecated.swift | 95 ++++++ .../New/Data/UUID+PostgresCodable.swift | 1 + .../New/PostgresChannelHandler.swift | 28 +- Sources/PostgresNIO/New/PostgresCodable.swift | 2 +- Sources/PostgresNIO/Utilities/Exports.swift | 8 +- .../PSQLIntegrationTests.swift | 14 +- Tests/IntegrationTests/PostgresNIOTests.swift | 49 +++- Tests/IntegrationTests/Utilities.swift | 46 ++- .../New/PSQLConnectionTests.swift | 4 +- .../New/PostgresChannelHandlerTests.swift | 52 ++-- 13 files changed, 539 insertions(+), 264 deletions(-) create mode 100644 Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift create mode 100644 Sources/PostgresNIO/Deprecated/PostgresConnection+Configuration+Deprecated.swift diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5945d014..66516611 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,6 +21,14 @@ jobs: - swift:5.8-jammy - swiftlang/swift:nightly-5.9-jammy - swiftlang/swift:nightly-main-jammy + include: + - coverage: true + # https://github.com/apple/swift-package-manager/issues/5853 + - container: swift:5.8-jammy + coverage: false + # https://github.com/apple/swift/issues/65064 + - container: swiftlang/swift:nightly-main-jammy + coverage: false container: ${{ matrix.container }} runs-on: ubuntu-latest env: @@ -29,9 +37,12 @@ jobs: - name: Check out package uses: actions/checkout@v3 - name: Run unit tests with code coverage and Thread Sanitizer - run: swift test --enable-test-discovery --filter=^PostgresNIOTests --sanitize=thread --enable-code-coverage + shell: bash + run: | + coverage=$( [[ '${{ matrix.coverage }}' == 'true' ]] && echo -n '--enable-code-coverage' || true ) + swift test --filter=^PostgresNIOTests --sanitize=thread ${coverage} - name: Submit coverage report to Codecov.io - if: ${{ !contains(matrix.container, '5.8') }} + if: ${{ matrix.coverage }} uses: vapor/swift-codecov-action@v0.2 with: cc_flags: 'unittests' @@ -58,6 +69,7 @@ jobs: dbauth: trust container: image: swift:5.8-jammy + volumes: [ 'pgrunshare:/var/run/postgresql' ] runs-on: ubuntu-latest env: LOG_LEVEL: debug @@ -74,10 +86,12 @@ jobs: POSTGRES_HOSTNAME: 'psql-a' POSTGRES_HOSTNAME_A: 'psql-a' POSTGRES_HOSTNAME_B: 'psql-b' + POSTGRES_SOCKET: '/var/run/postgresql/.s.PGSQL.5432' POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} services: psql-a: image: ${{ matrix.dbimage }} + volumes: [ 'pgrunshare:/var/run/postgresql' ] env: POSTGRES_USER: 'test_username' POSTGRES_DB: 'test_database' @@ -86,6 +100,7 @@ jobs: POSTGRES_INITDB_ARGS: --auth-host=${{ matrix.dbauth }} psql-b: image: ${{ matrix.dbimage }} + volumes: [ 'pgrunshare:/var/run/postgresql' ] env: POSTGRES_USER: 'test_username' POSTGRES_DB: 'test_database' @@ -134,6 +149,7 @@ jobs: POSTGRES_PASSWORD: 'test_password' POSTGRES_DB: 'postgres' POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} + POSTGRES_SOCKET: '/tmp/.s.PGSQL.5432' steps: - name: Select latest available Xcode uses: maxim-lobanov/setup-xcode@v1 diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift new file mode 100644 index 00000000..54eefc90 --- /dev/null +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift @@ -0,0 +1,276 @@ +import NIOCore +import NIOPosix // inet_pton() et al. +import NIOSSL + +extension PostgresConnection { + /// A configuration object for a connection + public struct Configuration { + + // MARK: - TLS + + /// The possible modes of operation for TLS encapsulation of a connection. + public struct TLS { + // MARK: Initializers + + /// Do not try to create a TLS connection to the server. + public static var disable: Self = .init(base: .disable) + + /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. + /// If the server does not support TLS, create an insecure connection. + public static func prefer(_ sslContext: NIOSSLContext) -> Self { + self.init(base: .prefer(sslContext)) + } + + /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. + /// If the server does not support TLS, fail the connection creation. + public static func require(_ sslContext: NIOSSLContext) -> Self { + self.init(base: .require(sslContext)) + } + + // MARK: Accessors + + /// Whether TLS will be attempted on the connection (`false` only when mode is ``disable``). + public var isAllowed: Bool { + if case .disable = self.base { return false } + else { return true } + } + + /// Whether TLS will be enforced on the connection (`true` only when mode is ``require(_:)``). + public var isEnforced: Bool { + if case .require(_) = self.base { return true } + else { return false } + } + + /// The `NIOSSLContext` that will be used. `nil` when TLS is disabled. + public var sslContext: NIOSSLContext? { + switch self.base { + case .prefer(let context), .require(let context): return context + case .disable: return nil + } + } + + // MARK: Implementation details + + enum Base { + case disable + case prefer(NIOSSLContext) + case require(NIOSSLContext) + } + let base: Base + private init(base: Base) { self.base = base } + } + + // MARK: - Connection options + + /// Describes options affecting how the underlying connection is made. + public struct Options { + /// A timeout for connection attempts. Defaults to ten seconds. + /// + /// Ignored when using a preexisting communcation channel. (See + /// ``PostgresConnection/Configuration/init(establishedChannel:username:password:database:)``.) + public var connectTimeout: TimeAmount + + /// The server name to use for certificate validation and SNI (Server Name Indication) when TLS is enabled. + /// Defaults to none (but see below). + /// + /// > When set to `nil`: + /// If the connection is made to a server over TCP using + /// ``PostgresConnection/Configuration/init(host:port:username:password:database:tls:)``, the given `host` + /// is used, unless it was an IP address string. If it _was_ an IP, or the connection is made by any other + /// method, SNI is disabled. + public var tlsServerName: String? + + /// Whether the connection is required to provide backend key data (internal Postgres stuff). + /// + /// This property is provided for compatibility with Amazon RDS Proxy, which requires it to be `false`. + /// If you are not using Amazon RDS Proxy, you should leave this set to `true` (the default). + public var requireBackendKeyData: Bool + + /// Create an options structure with default values. + /// + /// Most users should not need to adjust the defaults. + public init() { + self.connectTimeout = .seconds(10) + self.tlsServerName = nil + self.requireBackendKeyData = true + } + } + + // MARK: - Accessors + + /// The hostname to connect to for TCP configurations. + /// + /// Always `nil` for other configurations. + public var host: String? { + if case let .connectTCP(host, _) = self.endpointInfo { return host } + else { return nil } + } + + /// The port to connect to for TCP configurations. + /// + /// Always `nil` for other configurations. + public var port: Int? { + if case let .connectTCP(_, port) = self.endpointInfo { return port } + else { return nil } + } + + /// The socket path to connect to for Unix domain socket connections. + /// + /// Always `nil` for other configurations. + public var unixSocketPath: String? { + if case let .bindUnixDomainSocket(path) = self.endpointInfo { return path } + else { return nil } + } + + /// The `Channel` to use in existing-channel configurations. + /// + /// Always `nil` for other configurations. + public var establishedChannel: Channel? { + if case let .configureChannel(channel) = self.endpointInfo { return channel } + else { return nil } + } + + /// The TLS mode to use for the connection. Valid for all configurations. + /// + /// See ``TLS-swift.struct``. + public var tls: TLS + + /// Options for handling the communication channel. Most users don't need to change these. + /// + /// See ``Options-swift.struct``. + public var options: Options = .init() + + /// The username to connect with. + public var username: String + + /// The password, if any, for the user specified by ``username``. + /// + /// - Warning: `nil` means "no password provided", whereas `""` (the empty string) is a password of zero + /// length; these are not the same thing. + public var password: String? + + /// The name of the database to open. + /// + /// - Note: If set to `nil` or an empty string, the provided ``username`` is used. + public var database: String? + + // MARK: - Initializers + + /// Create a configuration for connecting to a server with a hostname and optional port. + /// + /// This specifies a TCP connection. If you're unsure which kind of connection you want, you almost + /// definitely want this one. + /// + /// - Parameters: + /// - host: The hostname to connect to. + /// - port: The TCP port to connect to (defaults to 5432). + /// - tls: The TLS mode to use. + public init(host: String, port: Int = 5432, username: String, password: String?, database: String?, tls: TLS) { + self.init(endpointInfo: .connectTCP(host: host, port: port), tls: tls, username: username, password: password, database: database) + } + + /// Create a configuration for connecting to a server through a UNIX domain socket. + /// + /// - Parameters: + /// - path: The filesystem path of the socket to connect to. + /// - tls: The TLS mode to use. Defaults to ``TLS-swift.struct/disable``. + public init(unixSocketPath: String, username: String, password: String?, database: String?) { + self.init(endpointInfo: .bindUnixDomainSocket(path: unixSocketPath), tls: .disable, username: username, password: password, database: database) + } + + /// Create a configuration for establishing a connection to a Postgres server over a preestablished + /// `NIOCore/Channel`. + /// + /// This is provided for calling code which wants to manage the underlying connection transport on its + /// own, such as when tunneling a connection through SSH. + /// + /// - Parameters: + /// - channel: The `NIOCore/Channel` to use. The channel must already be active and connected to an + /// endpoint (i.e. `NIOCore/Channel/isActive` must be `true`). + /// - tls: The TLS mode to use. Defaults to ``TLS-swift.struct/disable``. + public init(establishedChannel channel: Channel, username: String, password: String?, database: String?) { + self.init(endpointInfo: .configureChannel(channel), tls: .disable, username: username, password: password, database: database) + } + + // MARK: - Implementation details + + enum EndpointInfo { + case configureChannel(Channel) + case bindUnixDomainSocket(path: String) + case connectTCP(host: String, port: Int) + } + + var endpointInfo: EndpointInfo + + init(endpointInfo: EndpointInfo, tls: TLS, username: String, password: String?, database: String?) { + self.endpointInfo = endpointInfo + self.tls = tls + self.username = username + self.password = password + self.database = database + } + } +} + +// MARK: - Internal config details + +extension PostgresConnection { + /// A configuration object to bring the new ``PostgresConnection.Configuration`` together with + /// the deprecated configuration. + /// + /// TODO: Drop with next major release + struct InternalConfiguration { + enum Connection { + case unresolvedTCP(host: String, port: Int) + case unresolvedUDS(path: String) + case resolved(address: SocketAddress) + case bootstrapped(channel: Channel) + } + + let connection: InternalConfiguration.Connection + let username: String? + let password: String? + let database: String? + var tls: Configuration.TLS + let options: Configuration.Options + } +} + +extension PostgresConnection.InternalConfiguration { + init(_ config: PostgresConnection.Configuration) { + switch config.endpointInfo { + case .connectTCP(let host, let port): self.connection = .unresolvedTCP(host: host, port: port) + case .bindUnixDomainSocket(let path): self.connection = .unresolvedUDS(path: path) + case .configureChannel(let channel): self.connection = .bootstrapped(channel: channel) + } + self.username = config.username + self.password = config.password + self.database = config.database + self.tls = config.tls + self.options = config.options + } + + var serverNameForTLS: String? { + // If a name was explicitly configured, always use it. + if let tlsServerName = self.options.tlsServerName { return tlsServerName } + + // Otherwise, if the connection is TCP and the hostname wasn't an IP (not valid in SNI), use that. + if case .unresolvedTCP(let host, _) = self.connection, !host.isIPAddress() { return host } + + // Otherwise, disable SNI + return nil + } +} + +// originally taken from NIOSSL +private extension String { + func isIPAddress() -> Bool { + // We need some scratch space to let inet_pton write into. + var ipv4Addr = in_addr(), ipv6Addr = in6_addr() // inet_pton() assumes the provided address buffer is non-NULL + + /// N.B.: ``String/withCString(_:)`` is much more efficient than directly passing `self`, especially twice. + return self.withCString { ptr in + inet_pton(AF_INET, ptr, &ipv4Addr) == 1 || inet_pton(AF_INET6, ptr, &ipv6Addr) == 1 + } + } +} diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 2061e6bc..c24041c9 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -1,11 +1,11 @@ import Atomics import NIOCore +import NIOPosix #if canImport(Network) import NIOTransportServices #endif import NIOSSL import Logging -import NIOPosix /// A Postgres connection. Use it to run queries against a Postgres server. /// @@ -14,108 +14,6 @@ public final class PostgresConnection: @unchecked Sendable { /// A Postgres connection ID public typealias ID = Int - /// A configuration object for a connection - public struct Configuration { - /// A structure to configure the connection's authentication properties - public struct Authentication { - /// The username to connect with. - /// - /// - Default: postgres - public var username: String - - /// The database to open on the server - /// - /// - Default: `nil` - public var database: Optional - - /// The database user's password. - /// - /// - Default: `nil` - public var password: Optional - - public init(username: String, database: String?, password: String?) { - self.username = username - self.database = database - self.password = password - } - } - - public struct TLS { - enum Base { - case disable - case prefer(NIOSSLContext) - case require(NIOSSLContext) - } - - var base: Base - - private init(_ base: Base) { - self.base = base - } - - /// Do not try to create a TLS connection to the server. - public static var disable: Self = Self.init(.disable) - - /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. - /// If the server does not support TLS, create an insecure connection. - public static func prefer(_ sslContext: NIOSSLContext) -> Self { - self.init(.prefer(sslContext)) - } - - /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. - /// If the server does not support TLS, fail the connection creation. - public static func require(_ sslContext: NIOSSLContext) -> Self { - self.init(.require(sslContext)) - } - } - - public struct Connection { - /// The server to connect to - /// - /// - Default: localhost - public var host: String - - /// The server port to connect to. - /// - /// - Default: 5432 - public var port: Int - - /// Require connection to provide `BackendKeyData`. - /// For use with Amazon RDS Proxy, this must be set to false. - /// - /// - Default: true - public var requireBackendKeyData: Bool = true - - /// Specifies a timeout to apply to a connection attempt. - /// - /// - Default: 10 seconds - public var connectTimeout: TimeAmount - - public init(host: String, port: Int = 5432) { - self.host = host - self.port = port - self.connectTimeout = .seconds(10) - } - } - - public var connection: Connection - - /// The authentication properties to send to the Postgres server during startup auth handshake - public var authentication: Authentication - - public var tls: TLS - - public init( - connection: Connection, - authentication: Authentication, - tls: TLS - ) { - self.connection = connection - self.authentication = authentication - self.tls = tls - } - } - /// The connection's underlying channel /// /// This should be private, but it is needed for `PostgresConnection` compatibility. @@ -170,21 +68,21 @@ public final class PostgresConnection: @unchecked Sendable { func start(configuration: InternalConfiguration) -> EventLoopFuture { // 1. configure handlers - var configureSSLCallback: ((Channel) throws -> ())? = nil + let configureSSLCallback: ((Channel) throws -> ())? + switch configuration.tls.base { - case .disable: - break - - case .prefer(let sslContext), .require(let sslContext): + case .prefer(let context), .require(let context): configureSSLCallback = { channel in channel.eventLoop.assertInEventLoop() let sslHandler = try NIOSSLClientHandler( - context: sslContext, - serverHostname: configuration.sslServerHostname + context: context, + serverHostname: configuration.serverNameForTLS ) try channel.pipeline.syncOperations.addHandler(sslHandler, position: .first) } + case .disable: + configureSSLCallback = nil } let channelHandler = PostgresChannelHandler( @@ -206,7 +104,7 @@ public final class PostgresConnection: @unchecked Sendable { } let startupFuture: EventLoopFuture - if configuration.authentication == nil { + if configuration.username == nil { startupFuture = eventHandler.readyForStartupFuture } else { startupFuture = eventHandler.authenticateFuture @@ -269,10 +167,17 @@ public final class PostgresConnection: @unchecked Sendable { let bootstrap = self.makeBootstrap(on: eventLoop, configuration: configuration) switch configuration.connection { - case .resolved(let address, _): + case .resolved(let address): connectFuture = bootstrap.connect(to: address) - case .unresolved(let host, let port): + case .unresolvedTCP(let host, let port): connectFuture = bootstrap.connect(host: host, port: port) + case .unresolvedUDS(let path): + connectFuture = bootstrap.connect(unixDomainSocketPath: path) + case .bootstrapped(let channel): + guard channel.isActive else { + return eventLoop.makeFailedFuture(PSQLError.connectionError(underlying: ChannelError.alreadyClosed)) + } + connectFuture = eventLoop.makeSucceededFuture(channel) } return connectFuture.flatMap { channel -> EventLoopFuture in @@ -295,12 +200,12 @@ public final class PostgresConnection: @unchecked Sendable { ) -> NIOClientTCPBootstrapProtocol { #if canImport(Network) if let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) { - return tsBootstrap.connectTimeout(configuration.connectTimeout) + return tsBootstrap.connectTimeout(configuration.options.connectTimeout) } #endif if let nioBootstrap = ClientBootstrap(validatingGroup: eventLoop) { - return nioBootstrap.connectTimeout(configuration.connectTimeout) + return nioBootstrap.connectTimeout(configuration.options.connectTimeout) } fatalError("No matching bootstrap found") @@ -398,19 +303,22 @@ extension PostgresConnection { if let tlsConfiguration = tlsConfiguration { tlsFuture = eventLoop.makeSucceededVoidFuture().flatMapBlocking(onto: .global(qos: .default)) { - try PostgresConnection.Configuration.TLS.require(.init(configuration: tlsConfiguration)) + try .require(.init(configuration: tlsConfiguration)) } } else { tlsFuture = eventLoop.makeSucceededFuture(.disable) } return tlsFuture.flatMap { tls in + var options = PostgresConnection.Configuration.Options() + options.tlsServerName = serverHostname let configuration = PostgresConnection.InternalConfiguration( - connection: .resolved(address: socketAddress, serverName: serverHostname), - connectTimeout: .seconds(10), - authentication: nil, + connection: .resolved(address: socketAddress), + username: nil, + password: nil, + database: nil, tls: tls, - requireBackendKeyData: true + options: options ) return PostgresConnection.connect( @@ -733,66 +641,6 @@ enum CloseTarget { case portal(String) } -extension PostgresConnection.InternalConfiguration { - var sslServerHostname: String? { - switch self.connection { - case .unresolved(let host, _): - guard !host.isIPAddress() else { - return nil - } - return host - case .resolved(_, let serverName): - return serverName - } - } -} - -// copy and pasted from NIOSSL: -private extension String { - func isIPAddress() -> Bool { - // We need some scratch space to let inet_pton write into. - var ipv4Addr = in_addr() - var ipv6Addr = in6_addr() - - return self.withCString { ptr in - return inet_pton(AF_INET, ptr, &ipv4Addr) == 1 || - inet_pton(AF_INET6, ptr, &ipv6Addr) == 1 - } - } -} - -extension PostgresConnection { - /// A configuration object to bring the new ``PostgresConnection.Configuration`` together with - /// the deprecated configuration. - /// - /// TODO: Drop with next major release - struct InternalConfiguration { - enum Connection { - case unresolved(host: String, port: Int) - case resolved(address: SocketAddress, serverName: String?) - } - - var connection: Connection - var connectTimeout: TimeAmount - - var authentication: Configuration.Authentication? - - var tls: Configuration.TLS - - var requireBackendKeyData: Bool - } -} - -extension PostgresConnection.InternalConfiguration { - init(_ config: PostgresConnection.Configuration) { - self.authentication = config.authentication - self.connection = .unresolved(host: config.connection.host, port: config.connection.port) - self.connectTimeout = config.connection.connectTimeout - self.tls = config.tls - self.requireBackendKeyData = config.connection.requireBackendKeyData - } -} - extension EventLoopFuture { func enrichPSQLError(query: PostgresQuery, file: String, line: Int) -> EventLoopFuture { return self.flatMapErrorThrowing { error in diff --git a/Sources/PostgresNIO/Deprecated/PostgresConnection+Configuration+Deprecated.swift b/Sources/PostgresNIO/Deprecated/PostgresConnection+Configuration+Deprecated.swift new file mode 100644 index 00000000..9619c182 --- /dev/null +++ b/Sources/PostgresNIO/Deprecated/PostgresConnection+Configuration+Deprecated.swift @@ -0,0 +1,95 @@ +import NIOCore + +extension PostgresConnection.Configuration { + /// Legacy connection parameters structure. Replaced by ``PostgresConnection/Configuration/host`` etc. + @available(*, deprecated, message: "Use `Configuration.host` etc. instead.") + public struct Connection { + /// See ``PostgresConnection/Configuration/host``. + public var host: String + + /// See ``PostgresConnection/Configuration/port``. + public var port: Int + + /// See ``PostgresConnection/Configuration/Options-swift.struct/requireBackendKeyData``. + public var requireBackendKeyData: Bool = true + + /// See ``PostgresConnection/Configuration/Options-swift.struct/connectTimeout``. + public var connectTimeout: TimeAmount = .seconds(10) + + /// Create a configuration for connecting to a server. + /// + /// - Parameters: + /// - host: The hostname to connect to. + /// - port: The TCP port to connect to (defaults to 5432). + public init(host: String, port: Int = 5432) { + self.host = host + self.port = port + } + } + + /// Legacy authentication parameters structure. Replaced by ``PostgresConnection/Configuration/username`` etc. + @available(*, deprecated, message: "Use `Configuration.username` etc. instead.") + public struct Authentication { + /// See ``PostgresConnection/Configuration/username``. + public var username: String + + /// See ``PostgresConnection/Configuration/password``. + public var password: String? + + /// See ``PostgresConnection/Configuration/database``. + public var database: String? + + public init(username: String, database: String?, password: String?) { + self.username = username + self.database = database + self.password = password + } + } + + /// Accessor for legacy connection parameters. Replaced by ``PostgresConnection/Configuration/host`` etc. + @available(*, deprecated, message: "Use `Configuration.host` etc. instead.") + public var connection: Connection { + get { + var conn: Connection + switch self.endpointInfo { + case .connectTCP(let host, let port): + conn = .init(host: host, port: port) + case .bindUnixDomainSocket(_), .configureChannel(_): + conn = .init(host: "!invalid!", port: 0) // best we can do, really + } + conn.requireBackendKeyData = self.options.requireBackendKeyData + conn.connectTimeout = self.options.connectTimeout + return conn + } + set { + self.endpointInfo = .connectTCP(host: newValue.host, port: newValue.port) + self.options.connectTimeout = newValue.connectTimeout + self.options.requireBackendKeyData = newValue.requireBackendKeyData + } + } + + @available(*, deprecated, message: "Use `Configuration.username` etc. instead.") + public var authentication: Authentication { + get { + .init(username: self.username, database: self.database, password: self.password) + } + set { + self.username = newValue.username + self.password = newValue.password + self.database = newValue.database + } + } + + /// Legacy initializer. + /// Replaced by ``PostgresConnection/Configuration/init(host:port:username:password:database:tls:)`` etc. + @available(*, deprecated, message: "Use `init(host:port:username:password:database:tls:)` instead.") + public init(connection: Connection, authentication: Authentication, tls: TLS) { + self.init( + host: connection.host, port: connection.port, + username: authentication.username, password: authentication.password, database: authentication.database, + tls: tls + ) + self.options.connectTimeout = connection.connectTimeout + self.options.requireBackendKeyData = connection.requireBackendKeyData + } +} diff --git a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift index e44d77e5..1de0f394 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PostgresCodable.swift @@ -1,4 +1,5 @@ import NIOCore +import NIOFoundationCompat import struct Foundation.UUID import typealias Foundation.uuid_t import NIOFoundationCompat diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index ec02cd2c..5b9f4240 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -32,7 +32,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { logger: Logger, configureSSLCallback: ((Channel) throws -> Void)?) { - self.state = ConnectionStateMachine(requireBackendKeyData: configuration.requireBackendKeyData) + self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData) self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger @@ -284,11 +284,11 @@ final class PostgresChannelHandler: ChannelDuplexHandler { case .provideAuthenticationContext: context.fireUserInboundEventTriggered(PSQLEvent.readyForStartup) - if let authentication = self.configuration.authentication { + if let username = self.configuration.username { let authContext = AuthContext( - username: authentication.username, - password: authentication.password, - database: authentication.database + username: username, + password: self.configuration.password, + database: self.configuration.database ) let action = self.state.provideAuthenticationContext(authContext) return self.run(action, with: context) @@ -517,16 +517,6 @@ extension PostgresChannelHandler: PSQLRowsDataSource { } } -extension PostgresConnection.Configuration.Authentication { - func toAuthContext() -> AuthContext { - AuthContext( - username: self.username, - password: self.password, - database: self.database - ) - } -} - extension AuthContext { func toStartupParameters() -> PostgresFrontendMessage.Startup.Parameters { PostgresFrontendMessage.Startup.Parameters( @@ -577,12 +567,12 @@ private extension Insecure.MD5.Digest { extension ConnectionStateMachine.TLSConfiguration { fileprivate init(_ tls: PostgresConnection.Configuration.TLS) { - switch tls.base { - case .disable: + switch (tls.isAllowed, tls.isEnforced) { + case (false, _): self = .disable - case .require: + case (true, true): self = .require - case .prefer: + case (true, false): self = .prefer } } diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift index 68291eac..3aa1a24f 100644 --- a/Sources/PostgresNIO/New/PostgresCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -18,7 +18,7 @@ public protocol PostgresEncodable { } /// A type that can encode itself to a postgres wire binary representation. It enforces that the -/// ``PostgresEncodable/encode(into:context:)`` does not throw. This allows users +/// ``PostgresEncodable/encode(into:context:)-1jkcp`` does not throw. This allows users /// to create ``PostgresQuery``s using the `ExpressibleByStringInterpolation` without /// having to spell `try`. public protocol PostgresNonThrowingEncodable: PostgresEncodable { diff --git a/Sources/PostgresNIO/Utilities/Exports.swift b/Sources/PostgresNIO/Utilities/Exports.swift index 1c020411..5fc86b74 100644 --- a/Sources/PostgresNIO/Utilities/Exports.swift +++ b/Sources/PostgresNIO/Utilities/Exports.swift @@ -1,4 +1,10 @@ -#if !BUILDING_DOCC +#if compiler(>=5.8) + +@_documentation(visibility: internal) @_exported import NIO +@_documentation(visibility: internal) @_exported import NIOSSL +@_documentation(visibility: internal) @_exported import struct Logging.Logger + +#elseif !BUILDING_DOCC // TODO: Remove this with the next major release! @_exported import NIO diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 5debde90..4b2b9950 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -23,15 +23,11 @@ final class IntegrationTests: XCTestCase { try XCTSkipIf(env("POSTGRES_HOST_AUTH_METHOD") == "trust") let config = PostgresConnection.Configuration( - connection: .init( - host: env("POSTGRES_HOSTNAME") ?? "localhost", - port: 5432 - ), - authentication: .init( - username: env("POSTGRES_USER") ?? "test_username", - database: env("POSTGRES_DB") ?? "test_database", - password: "wrong_password" - ), + host: env("POSTGRES_HOSTNAME") ?? "localhost", + port: env("POSTGRES_PORT").flatMap(Int.init(_:)) ?? 5432, + username: env("POSTGRES_USER") ?? "test_username", + password: "wrong_password", + database: env("POSTGRES_DB") ?? "test_database", tls: .disable ) diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index 114ae2bc..348e6eb6 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -31,6 +31,18 @@ final class PostgresNIOTests: XCTestCase { XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) XCTAssertNoThrow(try conn?.close().wait()) } + + func testConnectUDSAndClose() throws { + try XCTSkipUnless(env("POSTGRES_SOCKET") != nil) + let conn = try PostgresConnection.testUDS(on: eventLoop).wait() + try conn.close().wait() + } + + func testConnectEstablishedChannelAndClose() throws { + let channel = try ClientBootstrap(group: self.group).connect(to: PostgresConnection.address()).wait() + let conn = try PostgresConnection.testChannel(channel, on: self.eventLoop).wait() + try conn.close().wait() + } func testSimpleQueryVersion() { var conn: PostgresConnection? @@ -42,6 +54,27 @@ final class PostgresNIOTests: XCTestCase { XCTAssertEqual(try rows?.first?.decode(String.self, context: .default).contains("PostgreSQL"), true) } + func testSimpleQueryVersionUsingUDS() throws { + try XCTSkipUnless(env("POSTGRES_SOCKET") != nil) + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.testUDS(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + var rows: [PostgresRow]? + XCTAssertNoThrow(rows = try conn?.simpleQuery("SELECT version()").wait()) + XCTAssertEqual(rows?.count, 1) + XCTAssertEqual(try rows?.first?.decode(String.self, context: .default).contains("PostgreSQL"), true) + } + + func testSimpleQueryVersionUsingEstablishedChannel() throws { + let channel = try ClientBootstrap(group: self.group).connect(to: PostgresConnection.address()).wait() + let conn = try PostgresConnection.testChannel(channel, on: self.eventLoop).wait() + defer { XCTAssertNoThrow(try conn.close().wait()) } + + let rows = try conn.simpleQuery("SELECT version()").wait() + XCTAssertEqual(rows.count, 1) + XCTAssertEqual(try rows.first?.decode(String.self, context: .default).contains("PostgreSQL"), true) + } + func testQueryVersion() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) @@ -744,19 +777,13 @@ final class PostgresNIOTests: XCTestCase { let logger = Logger(label: "test") let sslContext = try! NIOSSLContext(configuration: .makeClientConfiguration()) let config = PostgresConnection.Configuration( - connection: .init( - host: "elmer.db.elephantsql.com", - port: 5432 - ), - authentication: .init( - username: "uymgphwj", - database: "uymgphwj", - password: "7_tHbREdRwkqAdu4KoIS7hQnNxr8J1LA" - ), + host: "elmer.db.elephantsql.com", + port: 5432, + username: "uymgphwj", + password: "7_tHbREdRwkqAdu4KoIS7hQnNxr8J1LA", + database: "uymgphwj", tls: .require(sslContext) ) - - XCTAssertNoThrow(conn = try PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } var rows: [PostgresRow]? diff --git a/Tests/IntegrationTests/Utilities.swift b/Tests/IntegrationTests/Utilities.swift index faa19c42..b1788110 100644 --- a/Tests/IntegrationTests/Utilities.swift +++ b/Tests/IntegrationTests/Utilities.swift @@ -10,9 +10,9 @@ import Glibc extension PostgresConnection { static func address() throws -> SocketAddress { - try .makeAddressResolvingHost(env("POSTGRES_HOSTNAME") ?? "localhost", port: 5432) + try .makeAddressResolvingHost(env("POSTGRES_HOSTNAME") ?? "localhost", port: env("POSTGRES_PORT").flatMap(Int.init(_:)) ?? 5432) } - + @available(*, deprecated, message: "Test deprecated functionality") static func testUnauthenticated(on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { var logger = Logger(label: "postgres.connection.test") @@ -29,20 +29,44 @@ extension PostgresConnection { logger.logLevel = logLevel let config = PostgresConnection.Configuration( - connection: .init( - host: env("POSTGRES_HOSTNAME") ?? "localhost", - port: 5432 - ), - authentication: .init( - username: env("POSTGRES_USER") ?? "test_username", - database: env("POSTGRES_DB") ?? "test_database", - password: env("POSTGRES_PASSWORD") ?? "test_password" - ), + host: env("POSTGRES_HOSTNAME") ?? "localhost", + port: env("POSTGRES_PORT").flatMap(Int.init(_:)) ?? 5432, + username: env("POSTGRES_USER") ?? "test_username", + password: env("POSTGRES_PASSWORD") ?? "test_password", + database: env("POSTGRES_DB") ?? "test_database", tls: .disable ) return PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger) } + + static func testUDS(on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { + var logger = Logger(label: "postgres.connection.test") + logger.logLevel = logLevel + + let config = PostgresConnection.Configuration( + unixSocketPath: env("POSTGRES_SOCKET") ?? "/tmp/.s.PGSQL.\(env("POSTGRES_PORT").flatMap(Int.init(_:)) ?? 5432)", + username: env("POSTGRES_USER") ?? "test_username", + password: env("POSTGRES_PASSWORD") ?? "test_password", + database: env("POSTGRES_DB") ?? "test_database" + ) + + return PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger) + } + + static func testChannel(_ channel: Channel, on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { + var logger = Logger(label: "postgres.connection.test") + logger.logLevel = logLevel + + let config = PostgresConnection.Configuration( + establishedChannel: channel, + username: env("POSTGRES_USER") ?? "test_username", + password: env("POSTGRES_PASSWORD") ?? "test_password", + database: env("POSTGRES_DB") ?? "test_database" + ) + + return PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger) + } } extension Logger { diff --git a/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift b/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift index 2d50cb0f..2a58d4f6 100644 --- a/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLConnectionTests.swift @@ -22,8 +22,8 @@ class PSQLConnectionTests: XCTestCase { } let config = PostgresConnection.Configuration( - connection: .init(host: "127.0.0.1", port: port), - authentication: .init(username: "postgres", database: "postgres", password: "abc123"), + host: "127.0.0.1", port: port, + username: "postgres", password: "abc123", database: "postgres", tls: .disable ) diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index 298595c7..9e3bbefa 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -26,8 +26,8 @@ class PostgresChannelHandlerTests: XCTestCase { return XCTFail("Unexpected message") } - XCTAssertEqual(startup.parameters.user, config.authentication?.username) - XCTAssertEqual(startup.parameters.database, config.authentication?.database) + XCTAssertEqual(startup.parameters.user, config.username) + XCTAssertEqual(startup.parameters.database, config.database) XCTAssertEqual(startup.parameters.options, nil) XCTAssertEqual(startup.parameters.replication, .false) @@ -73,14 +73,13 @@ class PostgresChannelHandlerTests: XCTestCase { return XCTFail("Unexpected message") } - XCTAssertEqual(startupMessage.parameters.user, config.authentication?.username) - XCTAssertEqual(startupMessage.parameters.database, config.authentication?.database) + XCTAssertEqual(startupMessage.parameters.user, config.username) + XCTAssertEqual(startupMessage.parameters.database, config.database) XCTAssertEqual(startupMessage.parameters.replication, .false) } - func testSSLUnsupportedClosesConnection() { - var config = self.testConnectionConfiguration() - XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) + func testSSLUnsupportedClosesConnection() throws { + let config = self.testConnectionConfiguration(tls: .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) let handler = PostgresChannelHandler(configuration: config) { channel in XCTFail("This callback should never be exectuded") @@ -92,14 +91,14 @@ class PostgresChannelHandlerTests: XCTestCase { handler ]) let eventHandler = TestEventHandler() - XCTAssertNoThrow(try embedded.pipeline.addHandler(eventHandler, position: .last).wait()) + try embedded.pipeline.addHandler(eventHandler, position: .last).wait() - XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil)) + embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil) XCTAssertTrue(embedded.isActive) // read the ssl request message XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .sslRequest(.init())) - XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.sslUnsupported)) + try embedded.writeInbound(PostgresBackendMessage.sslUnsupported) // the event handler should have seen an error XCTAssertEqual(eventHandler.errors.count, 1) @@ -113,9 +112,9 @@ class PostgresChannelHandlerTests: XCTestCase { func testRunAuthenticateMD5Password() { let config = self.testConnectionConfiguration() let authContext = AuthContext( - username: config.authentication?.username ?? "something wrong", - password: config.authentication?.password, - database: config.authentication?.database + username: config.username ?? "something wrong", + password: config.password, + database: config.database ) let state = ConnectionStateMachine(.waitingToStartAuthentication) let handler = PostgresChannelHandler(configuration: config, state: state, configureSSLCallback: nil) @@ -138,13 +137,11 @@ class PostgresChannelHandlerTests: XCTestCase { func testRunAuthenticateCleartext() { let password = "postgres" - var config = self.testConnectionConfiguration() - config.authentication?.password = password - + let config = self.testConnectionConfiguration(password: password) let authContext = AuthContext( - username: config.authentication?.username ?? "something wrong", - password: config.authentication?.password, - database: config.authentication?.database + username: config.username ?? "something wrong", + password: config.password, + database: config.database ) let state = ConnectionStateMachine(.waitingToStartAuthentication) let handler = PostgresChannelHandler(configuration: config, state: state, configureSSLCallback: nil) @@ -177,18 +174,17 @@ class PostgresChannelHandlerTests: XCTestCase { connectTimeout: TimeAmount = .seconds(10), requireBackendKeyData: Bool = true ) -> PostgresConnection.InternalConfiguration { - let authentication = PostgresConnection.Configuration.Authentication( - username: username, - database: database, - password: password - ) + var options = PostgresConnection.Configuration.Options() + options.connectTimeout = connectTimeout + options.requireBackendKeyData = requireBackendKeyData return PostgresConnection.InternalConfiguration( - connection: .unresolved(host: host, port: port), - connectTimeout: connectTimeout, - authentication: authentication, + connection: .unresolvedTCP(host: host, port: port), + username: username, + password: password, + database: database, tls: tls, - requireBackendKeyData: requireBackendKeyData + options: options ) } } From 004d92aa8abd3db866c6de4e77a592f9d90a3be0 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 20 Apr 2023 14:44:26 +0200 Subject: [PATCH 141/292] Update Readme (#344) --- README.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index b82200b4..2723eb4b 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [![Team Chat](https://img.shields.io/discord/431917998102675485.svg)][Team Chat] [![MIT License](http://img.shields.io/badge/license-MIT-brightgreen.svg)][MIT License] [![Continuous Integration](https://github.com/vapor/postgres-nio/actions/workflows/test.yml/badge.svg)][Continuous Integration] -[![Swift 5.5](http://img.shields.io/badge/swift-5.5-brightgreen.svg)][Swift 5.5] +[![Swift 5.6](http://img.shields.io/badge/swift-5.6-brightgreen.svg)][Swift 5.6]

@@ -19,6 +19,7 @@ Features: - Integrated with the Swift server ecosystem, including use of [SwiftLog]. - Designed to run efficiently on all supported platforms (tested extensively on Linux and Darwin systems) - Support for `Network.framework` when available (e.g. on Apple platforms) +- Supports running on Unix Domain Sockets PostgresNIO does not provide a `ConnectionPool` as of today, but this is a [feature high on our list](https://github.com/vapor/postgres-nio/issues/256). If you need a `ConnectionPool` today, please have a look at Vapor's [PostgresKit]. @@ -35,7 +36,7 @@ Add `PostgresNIO` as dependency to your `Package.swift`: ```swift dependencies: [ - .package(url: "https://github.com/vapor/postgres-nio.git", from: "1.8.0"), + .package(url: "https://github.com/vapor/postgres-nio.git", from: "1.14.0"), ... ] ``` @@ -79,7 +80,7 @@ import NIOPosix let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) // Much later -try eventLoopGroup.syncShutdown() +try await eventLoopGroup.shutdownGracefully() ``` A [`Logger`] is also required. @@ -124,7 +125,7 @@ let connection = try await PostgresConnection.connect( try await connection.close() // Shutdown the EventLoopGroup, once all connections are closed. -try eventLoopGroup.syncShutdown() +try await eventLoopGroup.shutdownGracefully() ``` #### Querying @@ -148,7 +149,7 @@ for try await row in rows { However, in most cases it is much easier to request a row's fields as a set of Swift types: ```swift -for try await (id, username, birthday) in rows.decode((Int, String, Date).self, context: .default) { +for try await (id, username, birthday) in rows.decode((Int, String, Date).self) { // do something with the datatypes. } ``` @@ -191,7 +192,7 @@ Please see [SECURITY.md] for details on the security process. [Team Chat]: https://discord.gg/vapor [MIT License]: LICENSE [Continuous Integration]: https://github.com/vapor/postgres-nio/actions -[Swift 5.5]: https://swift.org +[Swift 5.6]: https://swift.org [Security.md]: https://github.com/vapor/.github/blob/main/SECURITY.md [`PostgresConnection`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresconnection/ From 263d0712461fd3dc09c289fffb1e37be3c7c38c3 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 20 Apr 2023 15:09:58 +0200 Subject: [PATCH 142/292] Readme update (#345) --- README.md | 32 ++++++++++++-------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 2723eb4b..51e0b8c5 100644 --- a/README.md +++ b/README.md @@ -58,16 +58,12 @@ To create a connection, first create a connection configuration object: import PostgresNIO let config = PostgresConnection.Configuration( - connection: .init( - host: "localhost", - port: 5432 - ), - authentication: .init( - username: "my_username", - database: "my_database", - password: "my_password" - ), - tls: .disable + host: "localhost", + port: 5432, + username: "my_username", + password: "my_password", + database: "my_database", + tls: .disable ) ``` @@ -102,16 +98,12 @@ let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) let logger = Logger(label: "postgres-logger") let config = PostgresConnection.Configuration( - connection: .init( - host: "localhost", - port: 5432 - ), - authentication: .init( - username: "my_username", - database: "my_database", - password: "my_password" - ), - tls: .disable + host: "localhost", + port: 5432, + username: "my_username", + password: "my_password", + database: "my_database", + tls: .disable ) let connection = try await PostgresConnection.connect( From 369a9ee024439555bf7b7464750f55cc834d2d2e Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 21 Apr 2023 08:16:38 +0200 Subject: [PATCH 143/292] Add PostgresRow decoding tests, reformat Package.swift (#346) --- Package.swift | 53 +++++++------ .../New/PostgresRowTests.swift | 76 +++++++++++++++++++ 2 files changed, 107 insertions(+), 22 deletions(-) diff --git a/Package.swift b/Package.swift index afd064a1..5aa9121e 100644 --- a/Package.swift +++ b/Package.swift @@ -22,27 +22,36 @@ let package = Package( .package(url: "https://github.com/apple/swift-log.git", from: "1.5.2"), ], targets: [ - .target(name: "PostgresNIO", dependencies: [ - .product(name: "Atomics", package: "swift-atomics"), - .product(name: "Crypto", package: "swift-crypto"), - .product(name: "Logging", package: "swift-log"), - .product(name: "Metrics", package: "swift-metrics"), - .product(name: "NIO", package: "swift-nio"), - .product(name: "NIOCore", package: "swift-nio"), - .product(name: "NIOPosix", package: "swift-nio"), - .product(name: "NIOTransportServices", package: "swift-nio-transport-services"), - .product(name: "NIOTLS", package: "swift-nio"), - .product(name: "NIOSSL", package: "swift-nio-ssl"), - .product(name: "NIOFoundationCompat", package: "swift-nio"), - ]), - .testTarget(name: "PostgresNIOTests", dependencies: [ - .target(name: "PostgresNIO"), - .product(name: "NIOEmbedded", package: "swift-nio"), - .product(name: "NIOTestUtils", package: "swift-nio"), - ]), - .testTarget(name: "IntegrationTests", dependencies: [ - .target(name: "PostgresNIO"), - .product(name: "NIOTestUtils", package: "swift-nio"), - ]), + .target( + name: "PostgresNIO", + dependencies: [ + .product(name: "Atomics", package: "swift-atomics"), + .product(name: "Crypto", package: "swift-crypto"), + .product(name: "Logging", package: "swift-log"), + .product(name: "Metrics", package: "swift-metrics"), + .product(name: "NIO", package: "swift-nio"), + .product(name: "NIOCore", package: "swift-nio"), + .product(name: "NIOPosix", package: "swift-nio"), + .product(name: "NIOTransportServices", package: "swift-nio-transport-services"), + .product(name: "NIOTLS", package: "swift-nio"), + .product(name: "NIOSSL", package: "swift-nio-ssl"), + .product(name: "NIOFoundationCompat", package: "swift-nio"), + ] + ), + .testTarget( + name: "PostgresNIOTests", + dependencies: [ + .target(name: "PostgresNIO"), + .product(name: "NIOEmbedded", package: "swift-nio"), + .product(name: "NIOTestUtils", package: "swift-nio"), + ] + ), + .testTarget( + name: "IntegrationTests", + dependencies: [ + .target(name: "PostgresNIO"), + .product(name: "NIOTestUtils", package: "swift-nio"), + ] + ), ] ) diff --git a/Tests/PostgresNIOTests/New/PostgresRowTests.swift b/Tests/PostgresNIOTests/New/PostgresRowTests.swift index c84b9baa..7be5a58a 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowTests.swift @@ -122,4 +122,80 @@ final class PostgresRowTests: XCTestCase { XCTAssertEqual(randomAccessRow["id"], PostgresCell(bytes: nil, dataType: .uuid, format: .binary, columnName: "id", columnIndex: 0)) XCTAssertEqual(randomAccessRow["name"], PostgresCell(bytes: ByteBuffer(string: "Hello world!"), dataType: .text, format: .binary, columnName: "name", columnIndex: 1)) } + + func testDecoding() { + let rowDescription = [ + RowDescription.Column( + name: "id", + tableOID: 1, + columnAttributeNumber: 1, + dataType: .uuid, + dataTypeSize: 0, + dataTypeModifier: 0, + format: .binary + ), + RowDescription.Column( + name: "name", + tableOID: 1, + columnAttributeNumber: 1, + dataType: .text, + dataTypeSize: 0, + dataTypeModifier: 0, + format: .binary + ) + ] + + let row = PostgresRow( + data: .makeTestDataRow(nil, ByteBuffer(string: "Hello world!")), + lookupTable: ["id": 0, "name": 1], + columns: rowDescription + ) + + var result: (UUID?, String)? + XCTAssertNoThrow(result = try row.decode((UUID?, String).self)) + XCTAssertEqual(result?.0, .some(.none)) + XCTAssertEqual(result?.1, "Hello world!") + } + + func testDecodingTypeMismatch() { + let rowDescription = [ + RowDescription.Column( + name: "id", + tableOID: 1, + columnAttributeNumber: 1, + dataType: .uuid, + dataTypeSize: 0, + dataTypeModifier: 0, + format: .binary + ), + RowDescription.Column( + name: "name", + tableOID: 1, + columnAttributeNumber: 1, + dataType: .int8, + dataTypeSize: 0, + dataTypeModifier: 0, + format: .binary + ) + ] + + let row = PostgresRow( + data: .makeTestDataRow(nil, ByteBuffer(integer: 123)), + lookupTable: ["id": 0, "name": 1], + columns: rowDescription + ) + + XCTAssertThrowsError(try row.decode((UUID?, String).self)) { error in + guard let psqlError = error as? PostgresDecodingError else { return XCTFail("Unexpected error type") } + + XCTAssertEqual(psqlError.columnName, "name") + XCTAssertEqual(psqlError.columnIndex, 1) + XCTAssertEqual(psqlError.line, #line - 5) + XCTAssertEqual(psqlError.file, #file) + XCTAssertEqual(psqlError.postgresData, ByteBuffer(integer: 123)) + XCTAssertEqual(psqlError.postgresFormat, .binary) + XCTAssertEqual(psqlError.postgresType, .int8) + XCTAssert(psqlError.targetType == String.self) + } + } } From 1516e0c5868b1bfe580f5d23ced3e44c681153f3 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 21 Apr 2023 14:26:14 +0200 Subject: [PATCH 144/292] Use #fileID and #filePath instead of #file (#348) --- Sources/PostgresNIO/Data/PostgresRow.swift | 4 +- .../Deprecated/PostgresData+UInt.swift | 2 +- .../New/Extensions/Logging+PSQL.swift | 14 ++--- .../New/PostgresBackendMessageDecoder.swift | 18 +++--- Sources/PostgresNIO/New/PostgresCell.swift | 4 +- .../New/PostgresRow-multi-decode.swift | 60 +++++++++--------- .../PostgresRowSequence-multi-decode.swift | 62 ++++++++++--------- Tests/IntegrationTests/AsyncTests.swift | 2 +- Tests/IntegrationTests/PerformanceTests.swift | 2 +- .../PSQLFrontendMessageDecoder.swift | 2 +- .../New/PostgresCellTests.swift | 2 +- .../New/PostgresCodableTests.swift | 2 +- .../New/PostgresRowTests.swift | 2 +- dev/generate-postgresrow-multi-decode.sh | 4 +- ...nerate-postgresrowsequence-multi-decode.sh | 4 +- 15 files changed, 93 insertions(+), 91 deletions(-) diff --git a/Sources/PostgresNIO/Data/PostgresRow.swift b/Sources/PostgresNIO/Data/PostgresRow.swift index af7758f4..e3aea692 100644 --- a/Sources/PostgresNIO/Data/PostgresRow.swift +++ b/Sources/PostgresNIO/Data/PostgresRow.swift @@ -217,7 +217,7 @@ extension PostgresRandomAccessRow { column: String, as type: T.Type, context: PostgresDecodingContext, - file: String = #file, line: Int = #line + file: String = #fileID, line: Int = #line ) throws -> T { guard let index = self.lookupTable[column] else { fatalError(#"A column "\#(column)" does not exist."#) @@ -237,7 +237,7 @@ extension PostgresRandomAccessRow { column index: Int, as type: T.Type, context: PostgresDecodingContext, - file: String = #file, line: Int = #line + file: String = #fileID, line: Int = #line ) throws -> T { precondition(index < self.columns.count) diff --git a/Sources/PostgresNIO/Deprecated/PostgresData+UInt.swift b/Sources/PostgresNIO/Deprecated/PostgresData+UInt.swift index ab3e493f..1f74a26c 100644 --- a/Sources/PostgresNIO/Deprecated/PostgresData+UInt.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresData+UInt.swift @@ -1,6 +1,6 @@ private func warn( _ old: Any.Type, mustBeConvertedTo new: Any.Type, - file: StaticString = #file, line: UInt = #line + file: StaticString = #filePath, line: UInt = #line ) { assertionFailure(""" Integer conversion unsafe. diff --git a/Sources/PostgresNIO/New/Extensions/Logging+PSQL.swift b/Sources/PostgresNIO/New/Extensions/Logging+PSQL.swift index ed83e84d..97c729f0 100644 --- a/Sources/PostgresNIO/New/Extensions/Logging+PSQL.swift +++ b/Sources/PostgresNIO/New/Extensions/Logging+PSQL.swift @@ -86,7 +86,7 @@ extension Logger { func trace(_ message: @autoclosure () -> Logger.Message, metadata: @autoclosure () -> PSQLLoggingMetadata, source: @autoclosure () -> String? = nil, - file: String = #file, function: String = #function, line: UInt = #line) { + file: String = #fileID, function: String = #function, line: UInt = #line) { self.log(level: .trace, message(), metadata: metadata().representation, source: source(), file: file, function: function, line: line) } @@ -95,7 +95,7 @@ extension Logger { func debug(_ message: @autoclosure () -> Logger.Message, metadata: @autoclosure () -> PSQLLoggingMetadata, source: @autoclosure () -> String? = nil, - file: String = #file, function: String = #function, line: UInt = #line) { + file: String = #fileID, function: String = #function, line: UInt = #line) { self.log(level: .debug, message(), metadata: metadata().representation, source: source(), file: file, function: function, line: line) } @@ -104,7 +104,7 @@ extension Logger { func info(_ message: @autoclosure () -> Logger.Message, metadata: @autoclosure () -> PSQLLoggingMetadata, source: @autoclosure () -> String? = nil, - file: String = #file, function: String = #function, line: UInt = #line) { + file: String = #fileID, function: String = #function, line: UInt = #line) { self.log(level: .info, message(), metadata: metadata().representation, source: source(), file: file, function: function, line: line) } @@ -113,7 +113,7 @@ extension Logger { func notice(_ message: @autoclosure () -> Logger.Message, metadata: @autoclosure () -> PSQLLoggingMetadata, source: @autoclosure () -> String? = nil, - file: String = #file, function: String = #function, line: UInt = #line) { + file: String = #fileID, function: String = #function, line: UInt = #line) { self.log(level: .notice, message(), metadata: metadata().representation, source: source(), file: file, function: function, line: line) } @@ -122,7 +122,7 @@ extension Logger { func warning(_ message: @autoclosure () -> Logger.Message, metadata: @autoclosure () -> PSQLLoggingMetadata, source: @autoclosure () -> String? = nil, - file: String = #file, function: String = #function, line: UInt = #line) { + file: String = #fileID, function: String = #function, line: UInt = #line) { self.log(level: .warning, message(), metadata: metadata().representation, source: source(), file: file, function: function, line: line) } @@ -131,7 +131,7 @@ extension Logger { func error(_ message: @autoclosure () -> Logger.Message, metadata: @autoclosure () -> PSQLLoggingMetadata, source: @autoclosure () -> String? = nil, - file: String = #file, function: String = #function, line: UInt = #line) { + file: String = #fileID, function: String = #function, line: UInt = #line) { self.log(level: .error, message(), metadata: metadata().representation, source: source(), file: file, function: function, line: line) } @@ -140,7 +140,7 @@ extension Logger { func critical(_ message: @autoclosure () -> Logger.Message, metadata: @autoclosure () -> PSQLLoggingMetadata, source: @autoclosure () -> String? = nil, - file: String = #file, function: String = #function, line: UInt = #line) { + file: String = #fileID, function: String = #function, line: UInt = #line) { self.log(level: .critical, message(), metadata: metadata().representation, source: source(), file: file, function: function, line: line) } } diff --git a/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift b/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift index 4e3b630e..ee7e1b84 100644 --- a/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift +++ b/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift @@ -123,7 +123,7 @@ struct PostgresMessageDecodingError: Error { static func unknownMessageIDReceived( messageID: UInt8, messageBytes: ByteBuffer, - file: String = #file, + file: String = #fileID, line: Int = #line) -> Self { var byteBuffer = messageBytes @@ -152,7 +152,7 @@ struct PSQLPartialDecodingError: Error { static func valueNotRawRepresentable( value: Target.RawValue, asType: Target.Type, - file: String = #file, + file: String = #fileID, line: Int = #line) -> Self { return PSQLPartialDecodingError( @@ -160,31 +160,31 @@ struct PSQLPartialDecodingError: Error { file: file, line: line) } - static func unexpectedValue(value: Any, file: String = #file, line: Int = #line) -> Self { + static func unexpectedValue(value: Any, file: String = #fileID, line: Int = #line) -> Self { return PSQLPartialDecodingError( description: "Value '\(value)' is not expected.", file: file, line: line) } - static func expectedAtLeastNRemainingBytes(_ expected: Int, actual: Int, file: String = #file, line: Int = #line) -> Self { + static func expectedAtLeastNRemainingBytes(_ expected: Int, actual: Int, file: String = #fileID, line: Int = #line) -> Self { return PSQLPartialDecodingError( description: "Expected at least '\(expected)' remaining bytes. But only found \(actual).", file: file, line: line) } - static func expectedExactlyNRemainingBytes(_ expected: Int, actual: Int, file: String = #file, line: Int = #line) -> Self { + static func expectedExactlyNRemainingBytes(_ expected: Int, actual: Int, file: String = #fileID, line: Int = #line) -> Self { return PSQLPartialDecodingError( description: "Expected exactly '\(expected)' remaining bytes. But found \(actual).", file: file, line: line) } - static func fieldNotDecodable(type: Any.Type, file: String = #file, line: Int = #line) -> Self { + static func fieldNotDecodable(type: Any.Type, file: String = #fileID, line: Int = #line) -> Self { return PSQLPartialDecodingError( description: "Could not read '\(type)' from ByteBuffer.", file: file, line: line) } - static func integerMustBePositiveOrNull(_ actual: Number, file: String = #file, line: Int = #line) -> Self { + static func integerMustBePositiveOrNull(_ actual: Number, file: String = #fileID, line: Int = #line) -> Self { return PSQLPartialDecodingError( description: "Expected the integer to be positive or null, but got \(actual).", file: file, line: line) @@ -192,14 +192,14 @@ struct PSQLPartialDecodingError: Error { } extension ByteBuffer { - mutating func throwingReadInteger(as: I.Type, file: String = #file, line: Int = #line) throws -> I { + mutating func throwingReadInteger(as: I.Type, file: String = #fileID, line: Int = #line) throws -> I { guard let result = self.readInteger(endianness: .big, as: I.self) else { throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(MemoryLayout.size, actual: self.readableBytes, file: file, line: line) } return result } - mutating func throwingMoveReaderIndex(forwardBy offset: Int, file: String = #file, line: Int = #line) throws { + mutating func throwingMoveReaderIndex(forwardBy offset: Int, file: String = #fileID, line: Int = #line) throws { guard self.readSlice(length: offset) != nil else { throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(offset, actual: self.readableBytes, file: file, line: line) } diff --git a/Sources/PostgresNIO/New/PostgresCell.swift b/Sources/PostgresNIO/New/PostgresCell.swift index d3cf8d4e..7598a31a 100644 --- a/Sources/PostgresNIO/New/PostgresCell.swift +++ b/Sources/PostgresNIO/New/PostgresCell.swift @@ -43,7 +43,7 @@ extension PostgresCell { public func decode( _: T.Type, context: PostgresDecodingContext, - file: String = #file, + file: String = #fileID, line: Int = #line ) throws -> T { var copy = self.bytes @@ -80,7 +80,7 @@ extension PostgresCell { @inlinable public func decode( _: T.Type, - file: String = #file, + file: String = #fileID, line: Int = #line ) throws -> T { try self.decode(T.self, context: .default, file: file, line: line) diff --git a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift index 4fe396ec..cb62c325 100644 --- a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift +++ b/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift @@ -3,7 +3,7 @@ extension PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0) { + public func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0) { precondition(self.columns.count >= 1) let columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -33,13 +33,13 @@ extension PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0).Type, file: String = #file, line: Int = #line) throws -> (T0) { + public func decode(_: (T0).Type, file: String = #fileID, line: Int = #line) throws -> (T0) { try self.decode(T0.self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1) { + public func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1) { precondition(self.columns.count >= 2) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -75,13 +75,13 @@ extension PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, file: String = #file, line: Int = #line) throws -> (T0, T1) { + public func decode(_: (T0, T1).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1) { try self.decode((T0, T1).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2) { + public func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2) { precondition(self.columns.count >= 3) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -123,13 +123,13 @@ extension PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2) { + public func decode(_: (T0, T1, T2).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2) { try self.decode((T0, T1, T2).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3) { + public func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3) { precondition(self.columns.count >= 4) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -177,13 +177,13 @@ extension PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3) { + public func decode(_: (T0, T1, T2, T3).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3) { try self.decode((T0, T1, T2, T3).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { + public func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { precondition(self.columns.count >= 5) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -237,13 +237,13 @@ extension PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { + public func decode(_: (T0, T1, T2, T3, T4).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { try self.decode((T0, T1, T2, T3, T4).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { + public func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { precondition(self.columns.count >= 6) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -303,13 +303,13 @@ extension PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { + public func decode(_: (T0, T1, T2, T3, T4, T5).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { try self.decode((T0, T1, T2, T3, T4, T5).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { precondition(self.columns.count >= 7) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -375,13 +375,13 @@ extension PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { try self.decode((T0, T1, T2, T3, T4, T5, T6).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { precondition(self.columns.count >= 8) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -453,13 +453,13 @@ extension PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { try self.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { precondition(self.columns.count >= 9) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -537,13 +537,13 @@ extension PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { precondition(self.columns.count >= 10) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -627,13 +627,13 @@ extension PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { precondition(self.columns.count >= 11) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -723,13 +723,13 @@ extension PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { precondition(self.columns.count >= 12) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -825,13 +825,13 @@ extension PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { precondition(self.columns.count >= 13) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -933,13 +933,13 @@ extension PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { precondition(self.columns.count >= 14) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -1047,13 +1047,13 @@ extension PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { precondition(self.columns.count >= 15) var columnIndex = 0 var cellIterator = self.data.makeIterator() @@ -1167,7 +1167,7 @@ extension PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: .default, file: file, line: line) } } diff --git a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift index ff212d0a..53d9a7ea 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift @@ -1,9 +1,10 @@ /// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrowsequence-multi-decode.sh +#if canImport(_Concurrency) extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode(T0.self, context: context, file: file, line: line) } @@ -11,13 +12,13 @@ extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.decode(T0.self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1).self, context: context, file: file, line: line) } @@ -25,13 +26,13 @@ extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.decode((T0, T1).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2).self, context: context, file: file, line: line) } @@ -39,13 +40,13 @@ extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.decode((T0, T1, T2).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3).self, context: context, file: file, line: line) } @@ -53,13 +54,13 @@ extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.decode((T0, T1, T2, T3).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4).self, context: context, file: file, line: line) } @@ -67,13 +68,13 @@ extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.decode((T0, T1, T2, T3, T4).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5).self, context: context, file: file, line: line) } @@ -81,13 +82,13 @@ extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.decode((T0, T1, T2, T3, T4, T5).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6).self, context: context, file: file, line: line) } @@ -95,13 +96,13 @@ extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.decode((T0, T1, T2, T3, T4, T5, T6).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: context, file: file, line: line) } @@ -109,13 +110,13 @@ extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: context, file: file, line: line) } @@ -123,13 +124,13 @@ extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: context, file: file, line: line) } @@ -137,13 +138,13 @@ extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: context, file: file, line: line) } @@ -151,13 +152,13 @@ extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: context, file: file, line: line) } @@ -165,13 +166,13 @@ extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: context, file: file, line: line) } @@ -179,13 +180,13 @@ extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: context, file: file, line: line) } @@ -193,13 +194,13 @@ extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: .default, file: file, line: line) } @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.map { row in try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: context, file: file, line: line) } @@ -207,7 +208,8 @@ extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence { + public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: .default, file: file, line: line) } } +#endif diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 6857e461..6fa47c2a 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -146,7 +146,7 @@ extension XCTestCase { func withTestConnection( on eventLoop: EventLoop, - file: StaticString = #file, + file: StaticString = #filePath, line: UInt = #line, _ closure: (PostgresConnection) async throws -> Result ) async throws -> Result { diff --git a/Tests/IntegrationTests/PerformanceTests.swift b/Tests/IntegrationTests/PerformanceTests.swift index 5d30db5e..6f730560 100644 --- a/Tests/IntegrationTests/PerformanceTests.swift +++ b/Tests/IntegrationTests/PerformanceTests.swift @@ -273,7 +273,7 @@ private func prepareTableToMeasureSelectPerformance( schema: String, fixtureData: [PostgresData], on eventLoop: EventLoop, - file: StaticString = #file, + file: StaticString = #filePath, line: UInt = #line ) throws { XCTAssertEqual(rowCount % batchSize, 0, "`rowCount` must be a multiple of `batchSize`", file: (file), line: line) diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index fc3f8858..311c41bd 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -157,7 +157,7 @@ extension PostgresMessageDecodingError { static func unknownStartupCodeReceived( code: UInt32, messageBytes: ByteBuffer, - file: String = #file, + file: String = #fileID, line: Int = #line) -> Self { var byteBuffer = messageBytes diff --git a/Tests/PostgresNIOTests/New/PostgresCellTests.swift b/Tests/PostgresNIOTests/New/PostgresCellTests.swift index 7df5ac9f..6458d063 100644 --- a/Tests/PostgresNIOTests/New/PostgresCellTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresCellTests.swift @@ -45,7 +45,7 @@ final class PostgresCellTests: XCTestCase { return XCTFail("Unexpected error") } - XCTAssertEqual(error.file, #file) + XCTAssertEqual(error.file, #fileID) XCTAssertEqual(error.line, #line - 6) XCTAssertEqual(error.code, .typeMismatch) XCTAssertEqual(error.columnName, "hello") diff --git a/Tests/PostgresNIOTests/New/PostgresCodableTests.swift b/Tests/PostgresNIOTests/New/PostgresCodableTests.swift index c1ef041e..94a0253b 100644 --- a/Tests/PostgresNIOTests/New/PostgresCodableTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresCodableTests.swift @@ -55,7 +55,7 @@ final class PostgresCodableTests: XCTestCase { XCTAssertThrowsError(try row.decode(String.self, context: .default)) { XCTAssertEqual(($0 as? PostgresDecodingError)?.line, #line - 1) - XCTAssertEqual(($0 as? PostgresDecodingError)?.file, #file) + XCTAssertEqual(($0 as? PostgresDecodingError)?.file, #fileID) XCTAssertEqual(($0 as? PostgresDecodingError)?.code, .missingData) XCTAssert(($0 as? PostgresDecodingError)?.targetType == String.self) diff --git a/Tests/PostgresNIOTests/New/PostgresRowTests.swift b/Tests/PostgresNIOTests/New/PostgresRowTests.swift index 7be5a58a..7aa4c7e6 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowTests.swift @@ -191,7 +191,7 @@ final class PostgresRowTests: XCTestCase { XCTAssertEqual(psqlError.columnName, "name") XCTAssertEqual(psqlError.columnIndex, 1) XCTAssertEqual(psqlError.line, #line - 5) - XCTAssertEqual(psqlError.file, #file) + XCTAssertEqual(psqlError.file, #fileID) XCTAssertEqual(psqlError.postgresData, ByteBuffer(integer: 123)) XCTAssertEqual(psqlError.postgresFormat, .binary) XCTAssertEqual(psqlError.postgresType, .int8) diff --git a/dev/generate-postgresrow-multi-decode.sh b/dev/generate-postgresrow-multi-decode.sh index 64a37417..e641ed8d 100755 --- a/dev/generate-postgresrow-multi-decode.sh +++ b/dev/generate-postgresrow-multi-decode.sh @@ -22,7 +22,7 @@ function genWithContextParameter() { for ((n = 1; n<$how_many; n +=1)); do echo -n ", T$(($n))" done - echo -n ").Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) throws" + echo -n ").Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws" echo -n " -> (T0" for ((n = 1; n<$how_many; n +=1)); do @@ -97,7 +97,7 @@ function genWithoutContextParameter() { for ((n = 1; n<$how_many; n +=1)); do echo -n ", T$(($n))" done - echo -n ").Type, file: String = #file, line: Int = #line) throws" + echo -n ").Type, file: String = #fileID, line: Int = #line) throws" echo -n " -> (T0" for ((n = 1; n<$how_many; n +=1)); do diff --git a/dev/generate-postgresrowsequence-multi-decode.sh b/dev/generate-postgresrowsequence-multi-decode.sh index 126f2a61..8317149b 100755 --- a/dev/generate-postgresrowsequence-multi-decode.sh +++ b/dev/generate-postgresrowsequence-multi-decode.sh @@ -22,7 +22,7 @@ function genWithContextParameter() { for ((n = 1; n<$how_many; n +=1)); do echo -n ", T$(($n))" done - echo -n ").Type, context: PostgresDecodingContext, file: String = #file, line: Int = #line) " + echo -n ").Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) " echo -n "-> AsyncThrowingMapSequence AsyncThrowingMapSequence Date: Thu, 27 Apr 2023 10:00:51 +0200 Subject: [PATCH 145/292] Fixes crash in queries that timeout (#351) --- .../New/PostgresChannelHandler.swift | 3 +- Tests/IntegrationTests/AsyncTests.swift | 38 +++++++++++++++++-- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 5b9f4240..7411039c 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -273,8 +273,9 @@ final class PostgresChannelHandler: ChannelDuplexHandler { case .forwardStreamError(let error, let read, let cleanupContext): - self.rowStream!.receive(completion: .failure(error)) + let rowStream = self.rowStream! self.rowStream = nil + rowStream.receive(completion: .failure(error)) if let cleanupContext = cleanupContext { self.closeConnectionAndCleanup(cleanupContext, context: context) } else if read { diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 6fa47c2a..ca9d80a1 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -36,13 +36,45 @@ final class AsyncPostgresConnectionTests: XCTestCase { try await withTestConnection(on: eventLoop) { connection in let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest) - var counter = 1 + var counter = 0 for try await element in rows.decode(Int.self, context: .default) { - XCTAssertEqual(element, counter) + XCTAssertEqual(element, counter + 1) counter += 1 } - XCTAssertEqual(counter, end + 1) + XCTAssertEqual(counter, end) + } + } + + func testSelectTimeoutWhileLongRunningQuery() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let start = 1 + let end = 10000000 + + try await withTestConnection(on: eventLoop) { connection -> () in + try await connection.query("SET statement_timeout=1000;", logger: .psqlTest) + + let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest) + var counter = 0 + do { + for try await element in rows.decode(Int.self, context: .default) { + XCTAssertEqual(element, counter + 1) + counter += 1 + } + XCTFail("Expected to get cancelled while reading the query") + } catch { + guard let error = error as? PSQLError else { return XCTFail("Unexpected error type") } + + print(error) + + XCTAssertEqual(error.code, .server) + XCTAssertEqual(error.serverInfo?[.severity], "ERROR") + } + + XCTAssertFalse(connection.isClosed, "Connection should survive!") } } From c692edafa9929d3d56ecb055c87f7035a4e6cb3f Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 28 Apr 2023 10:26:36 +0200 Subject: [PATCH 146/292] Remove #filePath when used in an assertion (#355) --- Sources/PostgresNIO/Deprecated/PostgresData+UInt.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/PostgresNIO/Deprecated/PostgresData+UInt.swift b/Sources/PostgresNIO/Deprecated/PostgresData+UInt.swift index 1f74a26c..ab3e493f 100644 --- a/Sources/PostgresNIO/Deprecated/PostgresData+UInt.swift +++ b/Sources/PostgresNIO/Deprecated/PostgresData+UInt.swift @@ -1,6 +1,6 @@ private func warn( _ old: Any.Type, mustBeConvertedTo new: Any.Type, - file: StaticString = #filePath, line: UInt = #line + file: StaticString = #file, line: UInt = #line ) { assertionFailure(""" Integer conversion unsafe. From dbf9c2eb596df39cba8ff3f74d74b2e6a31bd937 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 2 May 2023 13:25:52 +0200 Subject: [PATCH 147/292] Fix row stream cancel/error behavior (#353) --- Package.swift | 2 +- .../ConnectionStateMachine.swift | 10 +- .../ExtendedQueryStateMachine.swift | 12 ++- .../New/PostgresChannelHandler.swift | 4 +- Tests/IntegrationTests/AsyncTests.swift | 84 ++++++++++++++++ .../ExtendedQueryStateMachineTests.swift | 96 +++++++++++++++++++ 6 files changed, 200 insertions(+), 8 deletions(-) diff --git a/Package.swift b/Package.swift index 5aa9121e..52857cd2 100644 --- a/Package.swift +++ b/Package.swift @@ -14,7 +14,7 @@ let package = Package( ], dependencies: [ .package(url: "https://github.com/apple/swift-atomics.git", from: "1.1.0"), - .package(url: "https://github.com/apple/swift-nio.git", from: "2.50.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.51.1"), .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.16.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.23.1"), .package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "3.0.0"), diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index eeab0a81..65b43cd5 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -841,7 +841,7 @@ struct ConnectionStateMachine { // MARK: Consumer mutating func cancelQueryStream() -> ConnectionAction { - guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { + guard case .extendedQuery(var queryState, let connectionContext) = self.state else { preconditionFailure("Tried to cancel stream without active query") } @@ -926,6 +926,8 @@ struct ConnectionStateMachine { .wait, .read: preconditionFailure("Expecting only failure actions if an error happened") + case .evaluateErrorAtConnectionLevel: + return .closeConnectionAndCleanup(cleanupContext) case .failQuery(let queryContext, with: let error): return .failQuery(queryContext, with: error, cleanupContext: cleanupContext) case .forwardStreamError(let error, let read): @@ -1169,6 +1171,12 @@ extension ConnectionStateMachine { case .forwardStreamError(let error, let read): let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) return .forwardStreamError(error, read: read, cleanupContext: cleanupContext) + + case .evaluateErrorAtConnectionLevel(let error): + if let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) { + return .closeConnectionAndCleanup(cleanupContext) + } + return .wait case .read: return .read case .wait: diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index fdde1aa8..8b46fd0b 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -32,7 +32,9 @@ struct ExtendedQueryStateMachine { case failQuery(ExtendedQueryContext, with: PSQLError) case succeedQuery(ExtendedQueryContext, columns: [RowDescription.Column]) case succeedQueryNoRowsComming(ExtendedQueryContext, commandTag: String) - + + case evaluateErrorAtConnectionLevel(PSQLError) + // --- streaming actions // actions if query has requested next row but we are waiting for backend case forwardRows([DataRow]) @@ -422,11 +424,15 @@ struct ExtendedQueryStateMachine { .noDataMessageReceived(let context), .bindCompleteReceived(let context): self.state = .error(error) - return .failQuery(context, with: error) + if self.isCancelled { + return .evaluateErrorAtConnectionLevel(error) + } else { + return .failQuery(context, with: error) + } case .drain: self.state = .error(error) - return .forwardStreamError(error, read: false) + return .evaluateErrorAtConnectionLevel(error) case .streaming(_, var streamStateMachine): self.state = .error(error) diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 7411039c..a3cd1e4e 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -273,9 +273,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { case .forwardStreamError(let error, let read, let cleanupContext): - let rowStream = self.rowStream! + self.rowStream!.receive(completion: .failure(error)) self.rowStream = nil - rowStream.receive(completion: .failure(error)) if let cleanupContext = cleanupContext { self.closeConnectionAndCleanup(cleanupContext, context: context) } else if read { @@ -512,7 +511,6 @@ extension PostgresChannelHandler: PSQLRowsDataSource { guard self.rowStream === stream, let handlerContext = self.handlerContext else { return } - // we ignore this right now :) let action = self.state.cancelQueryStream() self.run(action, with: handlerContext) } diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index ca9d80a1..ed6910d1 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -75,6 +75,43 @@ final class AsyncPostgresConnectionTests: XCTestCase { } XCTAssertFalse(connection.isClosed, "Connection should survive!") + + for num in 0..<10 { + for try await decoded in try await connection.query("SELECT \(num);", logger: .psqlTest).decode(Int.self) { + XCTAssertEqual(decoded, num) + } + } + } + } + + func testConnectionSurvives1kQueriesWithATypo() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let start = 1 + let end = 10000 + + try await withTestConnection(on: eventLoop) { connection -> () in + for _ in 0..<1000 { + do { + try await connection.query("SELECT generte_series(\(start), \(end));", logger: .psqlTest) + XCTFail("Expected to throw from the request") + } catch { + guard let error = error as? PSQLError else { return XCTFail("Unexpected error type: \(error)") } + + XCTAssertEqual(error.code, .server) + XCTAssertEqual(error.serverInfo?[.severity], "ERROR") + } + } + + // the connection survived all of this, we can still run normal queries: + + for num in 0..<10 { + for try await decoded in try await connection.query("SELECT \(num);", logger: .psqlTest).decode(Int.self) { + XCTAssertEqual(decoded, num) + } + } } } @@ -172,6 +209,53 @@ final class AsyncPostgresConnectionTests: XCTestCase { } } #endif + + func testCancelTaskThatIsVeryLongRunningWhichAlsoFailsWhileInStreamingMode() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + // we cancel the query after 400ms. + // the server times out the query after 1sec. + + try await withTestConnection(on: eventLoop) { connection -> () in + try await connection.query("SET statement_timeout=1000;", logger: .psqlTest) // 1000 milliseconds + + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + let start = 1 + let end = 100_000_000 + + let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest) + var counter = 0 + do { + for try await element in rows.decode(Int.self, context: .default) { + XCTAssertEqual(element, counter + 1) + counter += 1 + } + XCTFail("Expected to get cancelled while reading the query") + XCTAssertEqual(counter, end) + } catch let error as CancellationError { + XCTAssertGreaterThanOrEqual(counter, 1) + // Expected + print("\(error)") + } catch { + XCTFail("Unexpected error: \(error)") + } + + XCTAssertTrue(Task.isCancelled) + XCTAssertFalse(connection.isClosed, "Connection should survive!") + } + + let delay: UInt64 = 400_000_000 // 400 milliseconds + try await Task.sleep(nanoseconds: delay) + + group.cancelAll() + } + + try await connection.query("SELECT 1;", logger: .psqlTest) + } + } } extension XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index 21c78fd1..eac46e5f 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -181,4 +181,100 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.commandCompletedReceived("SELECT 4"), .wait) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) } + + func testCancelQueryAfterServerError() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "SELECT version()" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + + // We need to ensure that even though the row description from the wire says that we + // will receive data in `.text` format, we will actually receive it in binary format, + // since we requested it in binary with our bind message. + let input: [RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) + ] + let expected: [RowDescription.Column] = input.map { + .init(name: $0.name, tableOID: $0.tableOID, columnAttributeNumber: $0.columnAttributeNumber, dataType: $0.dataType, + dataTypeSize: $0.dataTypeSize, dataTypeModifier: $0.dataTypeModifier, format: .binary) + } + + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(queryContext, columns: expected)) + let dataRows1: [DataRow] = [ + [ByteBuffer(string: "test1")], + [ByteBuffer(string: "test2")], + [ByteBuffer(string: "test3")] + ] + for row in dataRows1 { + XCTAssertEqual(state.dataRowReceived(row), .wait) + } + XCTAssertEqual(state.channelReadComplete(), .forwardRows(dataRows1)) + XCTAssertEqual(state.readEventCaught(), .wait) + XCTAssertEqual(state.requestQueryRows(), .read) + let dataRows2: [DataRow] = [ + [ByteBuffer(string: "test4")], + [ByteBuffer(string: "test5")], + [ByteBuffer(string: "test6")] + ] + for row in dataRows2 { + XCTAssertEqual(state.dataRowReceived(row), .wait) + } + let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"]) + XCTAssertEqual(state.errorReceived(serverError), .forwardStreamError(.server(serverError), read: false, cleanupContext: .none)) + + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testQueryErrorDoesNotKillConnection() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "SELECT version()" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + + let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"]) + XCTAssertEqual( + state.errorReceived(serverError), .failQuery(queryContext, with: .server(serverError), cleanupContext: .none) + ) + + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testQueryErrorAfterCancelDoesNotKillConnection() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "SELECT version()" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + XCTAssertEqual(state.cancelQueryStream(), .failQuery(queryContext, with: .queryCancelled, cleanupContext: .none)) + + let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"]) + XCTAssertEqual(state.errorReceived(serverError), .wait) + + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + } From dd1b740fff3847c594bc258f834356e53b87f9de Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Wed, 3 May 2023 07:24:00 -0500 Subject: [PATCH 148/292] Add PostgresDecodingError debugDescription (#358) --- Sources/PostgresNIO/New/PSQLError.swift | 34 ++++++++++++++++- .../New/PostgresErrorTests.swift | 37 ++++++++++++++++++- 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index a2fa9b5b..2c2bac2a 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -393,7 +393,7 @@ public struct PSQLError: Error { /// An error that may happen when a ``PostgresRow`` or ``PostgresCell`` is decoded to native Swift types. public struct PostgresDecodingError: Error, Equatable { - public struct Code: Hashable, Error { + public struct Code: Hashable, Error, CustomStringConvertible { enum Base { case missingData case typeMismatch @@ -409,6 +409,17 @@ public struct PostgresDecodingError: Error, Equatable { public static let missingData = Self.init(.missingData) public static let typeMismatch = Self.init(.typeMismatch) public static let failure = Self.init(.failure) + + public var description: String { + switch self.base { + case .missingData: + return "missingData" + case .typeMismatch: + return "typeMismatch" + case .failure: + return "failure" + } + } } /// The decoding error code @@ -476,3 +487,24 @@ extension PostgresDecodingError: CustomStringConvertible { "Database error" } } + +extension PostgresDecodingError: CustomDebugStringConvertible { + public var debugDescription: String { + var result = #"PostgresDecodingError(code: \#(self.code)"# + + result.append(#", columnName: \#(String(reflecting: self.columnName))"#) + result.append(#", columnIndex: \#(self.columnIndex)"#) + result.append(#", targetType: \#(String(reflecting: self.targetType))"#) + result.append(#", postgresType: \#(self.postgresType)"#) + result.append(#", postgresFormat: \#(self.postgresFormat)"#) + if let postgresData = self.postgresData { + result.append(#", postgresData: \#(postgresData.debugDescription)"#) // https://github.com/apple/swift-nio/pull/2418 + } + result.append(#", file: \#(self.file)"#) + result.append(#", line: \#(self.line)"#) + result.append(")") + + return result + } +} + diff --git a/Tests/PostgresNIOTests/New/PostgresErrorTests.swift b/Tests/PostgresNIOTests/New/PostgresErrorTests.swift index b1b78ff9..639d6b5e 100644 --- a/Tests/PostgresNIOTests/New/PostgresErrorTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresErrorTests.swift @@ -34,7 +34,7 @@ final class PostgresDecodingErrorTests: XCTestCase { } func testPostgresDecodingErrorDescription() { - let error = PostgresDecodingError( + let error1 = PostgresDecodingError( code: .typeMismatch, columnName: "column", columnIndex: 0, @@ -46,6 +46,39 @@ final class PostgresDecodingErrorTests: XCTestCase { line: 123 ) - XCTAssertEqual("\(error)", "Database error") + let error2 = PostgresDecodingError( + code: .missingData, + columnName: "column", + columnIndex: 0, + targetType: [[String: String]].self, + postgresType: .jsonbArray, + postgresFormat: .binary, + postgresData: nil, + file: "bar.swift", + line: 123 + ) + + // Plain description + XCTAssertEqual(String(describing: error1), "Database error") + XCTAssertEqual(String(describing: error2), "Database error") + + // Extended debugDescription + XCTAssertEqual(String(reflecting: error1), """ + PostgresDecodingError(code: typeMismatch,\ + columnName: "column", columnIndex: 0,\ + targetType: Swift.String,\ + postgresType: TEXT, postgresFormat: binary,\ + postgresData: \(error1.postgresData?.debugDescription ?? "nil"),\ + file: foo.swift, line: 123\ + ) + """) + XCTAssertEqual(String(reflecting: error2), """ + PostgresDecodingError(code: missingData,\ + columnName: "column", columnIndex: 0,\ + targetType: Swift.Array>,\ + postgresType: JSONB[], postgresFormat: binary,\ + file: bar.swift, line: 123\ + ) + """) } } From 2df54bc94607f44584ae6ffa74e3cd754fffafc7 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 4 May 2023 19:19:06 +0200 Subject: [PATCH 149/292] Merge pull request from GHSA-9cfh-vx93-84vv * Ensure empty incoming buffer when TLS negotiation starts * cleanup * Use real nio version --- Package.swift | 2 +- .../ConnectionStateMachine.swift | 38 ++++++++---------- Sources/PostgresNIO/New/PSQLError.swift | 6 +++ .../New/PostgresChannelHandler.swift | 2 +- Sources/PostgresNIO/Postgres+PSQLCompat.swift | 2 + .../ConnectionStateMachineTests.swift | 13 +++++-- .../New/PostgresChannelHandlerTests.swift | 39 ++++++++++++++++++- 7 files changed, 74 insertions(+), 28 deletions(-) diff --git a/Package.swift b/Package.swift index 52857cd2..c1cb4bda 100644 --- a/Package.swift +++ b/Package.swift @@ -14,7 +14,7 @@ let package = Package( ], dependencies: [ .package(url: "https://github.com/apple/swift-atomics.git", from: "1.1.0"), - .package(url: "https://github.com/apple/swift-nio.git", from: "2.51.1"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.52.0"), .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.16.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.23.1"), .package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "3.0.0"), diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 65b43cd5..563bb026 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -228,9 +228,12 @@ struct ConnectionStateMachine { } } - mutating func sslSupportedReceived() -> ConnectionAction { + mutating func sslSupportedReceived(unprocessedBytes: Int) -> ConnectionAction { switch self.state { case .sslRequestSent: + if unprocessedBytes > 0 { + return self.closeConnectionAndCleanup(.receivedUnencryptedDataAfterSSLRequest) + } self.state = .sslNegotiated return .establishSSLConnection @@ -1079,9 +1082,18 @@ extension ConnectionStateMachine { extension ConnectionStateMachine { func shouldCloseConnection(reason error: PSQLError) -> Bool { switch error.code.base { - case .sslUnsupported: - return true - case .failedToAddSSLHandler: + case .failedToAddSSLHandler, + .receivedUnencryptedDataAfterSSLRequest, + .sslUnsupported, + .messageDecodingFailure, + .unexpectedBackendMessage, + .unsupportedAuthMechanism, + .authMechanismRequiresPassword, + .saslError, + .tooManyParameters, + .invalidCommandTag, + .connectionError, + .uncleanShutdown: return true case .queryCancelled: return false @@ -1097,28 +1109,10 @@ extension ConnectionStateMachine { } return false - case .messageDecodingFailure: - return true - case .unexpectedBackendMessage: - return true - case .unsupportedAuthMechanism: - return true - case .authMechanismRequiresPassword: - return true - case .saslError: - return true - case .tooManyParameters: - return true - case .invalidCommandTag: - return true case .connectionQuiescing: preconditionFailure("Pure client error, that is thrown directly in PostgresConnection") case .connectionClosed: preconditionFailure("Pure client error, that is thrown directly and should never ") - case .connectionError: - return true - case .uncleanShutdown: - return true } } diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index 2c2bac2a..08b6a01e 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -7,6 +7,7 @@ public struct PSQLError: Error { enum Base: Sendable, Hashable { case sslUnsupported case failedToAddSSLHandler + case receivedUnencryptedDataAfterSSLRequest case server case messageDecodingFailure case unexpectedBackendMessage @@ -31,6 +32,7 @@ public struct PSQLError: Error { public static let sslUnsupported = Self.init(.sslUnsupported) public static let failedToAddSSLHandler = Self(.failedToAddSSLHandler) + public static let receivedUnencryptedDataAfterSSLRequest = Self(.receivedUnencryptedDataAfterSSLRequest) public static let server = Self(.server) public static let messageDecodingFailure = Self(.messageDecodingFailure) public static let unexpectedBackendMessage = Self(.unexpectedBackendMessage) @@ -51,6 +53,8 @@ public struct PSQLError: Error { return "sslUnsupported" case .failedToAddSSLHandler: return "failedToAddSSLHandler" + case .receivedUnencryptedDataAfterSSLRequest: + return "receivedUnencryptedDataAfterSSLRequest" case .server: return "server" case .messageDecodingFailure: @@ -343,6 +347,8 @@ public struct PSQLError: Error { static var uncleanShutdown: PSQLError { PSQLError(code: .uncleanShutdown) } + static var receivedUnencryptedDataAfterSSLRequest: PSQLError { PSQLError(code: .receivedUnencryptedDataAfterSSLRequest) } + static func server(_ response: PostgresBackendMessage.ErrorResponse) -> PSQLError { var error = PSQLError(code: .server) error.serverInfo = .init(response) diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index a3cd1e4e..84f07d47 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -139,7 +139,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { case .rowDescription(let rowDescription): action = self.state.rowDescriptionReceived(rowDescription) case .sslSupported: - action = self.state.sslSupportedReceived() + action = self.state.sslSupportedReceived(unprocessedBytes: self.decoder.unprocessedBytes) case .sslUnsupported: action = self.state.sslUnsupportedReceived() } diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index 55870f8a..ff9773f5 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -31,6 +31,8 @@ extension PSQLError { return PostgresError.protocol("Unsupported auth scheme: \(message)") case .authMechanismRequiresPassword: return PostgresError.protocol("Unable to authenticate without password") + case .receivedUnencryptedDataAfterSSLRequest: + return PostgresError.protocol("Received unencrypted data after SSL request") case .saslError: return self.underlying ?? self case .tooManyParameters, .invalidCommandTag: diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index eaf427d5..289665fb 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -19,20 +19,27 @@ class ConnectionStateMachineTests: XCTestCase { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest) - XCTAssertEqual(state.sslSupportedReceived(), .establishSSLConnection) + XCTAssertEqual(state.sslSupportedReceived(unprocessedBytes: 0), .establishSSLConnection) XCTAssertEqual(state.sslHandlerAdded(), .wait) XCTAssertEqual(state.sslEstablished(), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) let salt: (UInt8, UInt8, UInt8, UInt8) = (0,1,2,3) XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) } - + + func testSSLStartupFailureTooManyBytesRemaining() { + var state = ConnectionStateMachine(requireBackendKeyData: true) + XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest) + let failError = PSQLError.receivedUnencryptedDataAfterSSLRequest + XCTAssertEqual(state.sslSupportedReceived(unprocessedBytes: 1), .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: failError, closePromise: nil))) + } + func testSSLStartupFailHandler() { struct SSLHandlerAddError: Error, Equatable {} var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest) - XCTAssertEqual(state.sslSupportedReceived(), .establishSSLConnection) + XCTAssertEqual(state.sslSupportedReceived(unprocessedBytes: 0), .establishSSLConnection) let failError = PSQLError.failedToAddSSLHandler(underlying: SSLHandlerAddError()) XCTAssertEqual(state.errorHappened(failError), .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: failError, closePromise: nil))) } diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index 9e3bbefa..7ab0ce30 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -77,7 +77,44 @@ class PostgresChannelHandlerTests: XCTestCase { XCTAssertEqual(startupMessage.parameters.database, config.database) XCTAssertEqual(startupMessage.parameters.replication, .false) } - + + func testEstablishSSLCallbackIsNotCalledIfSSLIsSupportedButAnotherMEssageIsSentAsWell() { + var config = self.testConnectionConfiguration() + XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) + var addSSLCallbackIsHit = false + let handler = PostgresChannelHandler(configuration: config) { channel in + addSSLCallbackIsHit = true + } + let eventHandler = TestEventHandler() + let embedded = EmbeddedChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), + handler, + eventHandler + ]) + + var maybeMessage: PostgresFrontendMessage? + XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil)) + XCTAssertNoThrow(maybeMessage = try embedded.readOutbound(as: PostgresFrontendMessage.self)) + guard case .sslRequest(let request) = maybeMessage else { + return XCTFail("Unexpected message") + } + + XCTAssertEqual(request.code, 80877103) + + var responseBuffer = ByteBuffer() + responseBuffer.writeInteger(UInt8(ascii: "S")) + responseBuffer.writeInteger(UInt8(ascii: "1")) + XCTAssertNoThrow(try embedded.writeInbound(responseBuffer)) + + XCTAssertFalse(addSSLCallbackIsHit) + + // the event handler should have seen an error + XCTAssertEqual(eventHandler.errors.count, 1) + + // the connections should be closed + XCTAssertFalse(embedded.isActive) + } + func testSSLUnsupportedClosesConnection() throws { let config = self.testConnectionConfiguration(tls: .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) From e4cc928a07c84b009dc0baaaf5e69e426ae40d56 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Fri, 5 May 2023 01:05:35 -0500 Subject: [PATCH 150/292] Various cleanups to CI (#359) * Don't need the BUILDING_DOCC hack (and thus the extra exports check CI job) anymore. MASSIVELY simplify the projectboard workflow. Reenable CI coverage for the main nightly snapshot since the bug that was crashing the compiler's been fixed. * Use the new reusable project boards workflow that needs no extra work --- .github/workflows/projectboard.yml | 28 +++------------------ .github/workflows/test.yml | 15 ----------- Sources/PostgresNIO/Utilities/Exports.swift | 2 +- 3 files changed, 5 insertions(+), 40 deletions(-) diff --git a/.github/workflows/projectboard.yml b/.github/workflows/projectboard.yml index b857f6ee..a0e6d988 100644 --- a/.github/workflows/projectboard.yml +++ b/.github/workflows/projectboard.yml @@ -5,27 +5,7 @@ on: types: [reopened, closed, labeled, unlabeled, assigned, unassigned] jobs: - setup_matrix_input: - runs-on: ubuntu-latest - - steps: - - id: set-matrix - run: | - output=$(curl ${{ github.event.issue.url }}/labels | jq '.[] | .name') - - echo '======================' - echo 'Process incoming data' - echo '======================' - json=$(echo $output | sed 's/"\s"/","/g') - echo $json - echo "::set-output name=matrix::$(echo $json)" - outputs: - issueTags: ${{ steps.set-matrix.outputs.matrix }} - - Manage_project_issues: - needs: setup_matrix_input - uses: vapor/ci/.github/workflows/issues-to-project-board.yml@main - with: - labelsJson: ${{ needs.setup_matrix_input.outputs.issueTags }} - secrets: - PROJECT_BOARD_AUTOMATION_PAT: "${{ secrets.PROJECT_BOARD_AUTOMATION_PAT }}" + update_project_boards: + name: Update project boards + uses: vapor/ci/.github/workflows/update-project-boards-for-issue.yml@reusable-workflows + secrets: inherit diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 66516611..8f1f139d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -26,9 +26,6 @@ jobs: # https://github.com/apple/swift-package-manager/issues/5853 - container: swift:5.8-jammy coverage: false - # https://github.com/apple/swift/issues/65064 - - container: swiftlang/swift:nightly-main-jammy - coverage: false container: ${{ matrix.container }} runs-on: ubuntu-latest env: @@ -181,15 +178,3 @@ jobs: run: git config --global --add safe.directory ${GITHUB_WORKSPACE} - name: API breaking changes run: swift package diagnose-api-breaking-changes origin/main - - test-exports: - name: Test exports - runs-on: ubuntu-latest - container: swift:5.8-jammy - steps: - - name: Check out package - uses: actions/checkout@v3 - with: - fetch-depth: 0 - - name: Build - run: swift build -Xswiftc -DBUILDING_DOCC diff --git a/Sources/PostgresNIO/Utilities/Exports.swift b/Sources/PostgresNIO/Utilities/Exports.swift index 5fc86b74..204df50c 100644 --- a/Sources/PostgresNIO/Utilities/Exports.swift +++ b/Sources/PostgresNIO/Utilities/Exports.swift @@ -4,7 +4,7 @@ @_documentation(visibility: internal) @_exported import NIOSSL @_documentation(visibility: internal) @_exported import struct Logging.Logger -#elseif !BUILDING_DOCC +#else // TODO: Remove this with the next major release! @_exported import NIO From dbefcb022ca0148cc4cd8efd805246db4d3ccaee Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Sat, 6 May 2023 03:42:57 -0500 Subject: [PATCH 151/292] Fixup code coverage config in CI (#362) Fixup code coverage config - swift-codecov-action now works around the Swift 5.8 problem, we don't want the unittests flag or need verbose output, and we shouldn't fail CI if uploads fail (they fail often) --- .github/workflows/test.yml | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8f1f139d..be74e3b8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,11 +21,6 @@ jobs: - swift:5.8-jammy - swiftlang/swift:nightly-5.9-jammy - swiftlang/swift:nightly-main-jammy - include: - - coverage: true - # https://github.com/apple/swift-package-manager/issues/5853 - - container: swift:5.8-jammy - coverage: false container: ${{ matrix.container }} runs-on: ubuntu-latest env: @@ -34,19 +29,12 @@ jobs: - name: Check out package uses: actions/checkout@v3 - name: Run unit tests with code coverage and Thread Sanitizer - shell: bash - run: | - coverage=$( [[ '${{ matrix.coverage }}' == 'true' ]] && echo -n '--enable-code-coverage' || true ) - swift test --filter=^PostgresNIOTests --sanitize=thread ${coverage} + run: swift test --filter=^PostgresNIOTests --sanitize=thread --enable-code-coverage - name: Submit coverage report to Codecov.io - if: ${{ matrix.coverage }} uses: vapor/swift-codecov-action@v0.2 with: - cc_flags: 'unittests' cc_env_vars: 'SWIFT_VERSION,SWIFT_PLATFORM,RUNNER_OS,RUNNER_ARCH' - cc_fail_ci_if_error: true - cc_verbose: true - cc_dry_run: false + cc_fail_ci_if_error: false linux-integration-and-dependencies: if: github.event_name == 'pull_request' From 7524022ccfc4857ec399c55d6f92c2ada5420d9a Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Sun, 7 May 2023 12:05:27 -0500 Subject: [PATCH 152/292] Make PostgresCodable typealias public (#363) --- Sources/PostgresNIO/New/PostgresCodable.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift index 3aa1a24f..36937de4 100644 --- a/Sources/PostgresNIO/New/PostgresCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -81,7 +81,7 @@ extension PostgresDecodable { } /// A type that can be encoded into and decoded from a postgres binary format -typealias PostgresCodable = PostgresEncodable & PostgresDecodable +public typealias PostgresCodable = PostgresEncodable & PostgresDecodable extension PostgresEncodable { @inlinable From a290e4e73bc5a912d5c2289b2d173cf31636eeee Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 8 May 2023 22:45:51 +0200 Subject: [PATCH 153/292] Add `testSelectActiveConnection` test (#364) --- Tests/IntegrationTests/AsyncTests.swift | 40 ++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index ed6910d1..c96c81f5 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -46,6 +46,44 @@ final class AsyncPostgresConnectionTests: XCTestCase { } } + func testSelectActiveConnection() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let query: PostgresQuery = """ + SELECT + pid + ,datname + ,usename + ,application_name + ,client_hostname + ,client_port + ,backend_start + ,query_start + ,query + ,state + FROM pg_stat_activity + WHERE state = 'active'; + """ + + try await withTestConnection(on: eventLoop) { connection in + let rows = try await connection.query(query, logger: .psqlTest) + var counter = 0 + + for try await element in rows.decode((Int, String, String, String, String?, Int, Date, Date, String, String).self) { + XCTAssertEqual(element.1, env("POSTGRES_DB") ?? "localhost") + XCTAssertEqual(element.2, env("POSTGRES_USER") ?? "test_username") + + XCTAssertEqual(element.8, query.sql) + XCTAssertEqual(element.9, "active") + counter += 1 + } + + XCTAssertGreaterThanOrEqual(counter, 1) + } + } + func testSelectTimeoutWhileLongRunningQuery() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } @@ -273,7 +311,7 @@ extension XCTestCase { try await connection.close() return result } catch { - XCTFail("Unexpected error: \(error)", file: file, line: line) + XCTFail("Unexpected error: \(String(reflecting: error))", file: file, line: line) try await connection.close() throw error } From 8981a236bf4fc9e1185e64045836dbf6dbffec3c Mon Sep 17 00:00:00 2001 From: Zach Rausnitz Date: Tue, 9 May 2023 10:33:55 -0400 Subject: [PATCH 154/292] Add support for int4range, int8range, int4range[], int8range[] (#330) Co-authored-by: Fabian Fett --- .../PostgresNIO/Data/PostgresDataType.swift | 27 ++ .../New/Data/Array+PostgresCodable.swift | 12 + .../New/Data/Range+PostgresCodable.swift | 307 ++++++++++++++++++ Tests/IntegrationTests/PostgresNIOTests.swift | 216 ++++++++++++ .../New/Data/Array+PSQLCodableTests.swift | 16 + .../New/Data/Date+PSQLCodableTests.swift | 2 +- .../New/Data/Range+PSQLCodableTests.swift | 105 ++++++ 7 files changed, 684 insertions(+), 1 deletion(-) create mode 100644 Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift create mode 100644 Tests/PostgresNIOTests/New/Data/Range+PSQLCodableTests.swift diff --git a/Sources/PostgresNIO/Data/PostgresDataType.swift b/Sources/PostgresNIO/Data/PostgresDataType.swift index 50d2b0eb..d57f2529 100644 --- a/Sources/PostgresNIO/Data/PostgresDataType.swift +++ b/Sources/PostgresNIO/Data/PostgresDataType.swift @@ -115,6 +115,14 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri public static let jsonb = PostgresDataType(3802) /// `3807` _jsonb public static let jsonbArray = PostgresDataType(3807) + /// `3904` + public static let int4Range = PostgresDataType(3904) + /// `3905` _int4range + public static let int4RangeArray = PostgresDataType(3905) + /// `3926` + public static let int8Range = PostgresDataType(3926) + /// `3927` _int8range + public static let int8RangeArray = PostgresDataType(3927) /// The raw data type code recognized by PostgreSQL. public var rawValue: UInt32 @@ -180,6 +188,10 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .uuidArray: return "UUID[]" case .jsonb: return "JSONB" case .jsonbArray: return "JSONB[]" + case .int4Range: return "INT4RANGE" + case .int4RangeArray: return "INT4RANGE[]" + case .int8Range: return "INT8RANGE" + case .int8RangeArray: return "INT8RANGE[]" default: return nil } } @@ -201,6 +213,8 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .jsonb: return .jsonbArray case .text: return .textArray case .varchar: return .varcharArray + case .int4Range: return .int4RangeArray + case .int8Range: return .int8RangeArray default: return nil } } @@ -223,6 +237,19 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .jsonbArray: return .jsonb case .textArray: return .text case .varcharArray: return .varchar + case .int4RangeArray: return .int4Range + case .int8RangeArray: return .int8Range + default: return nil + } + } + + /// Returns the bound type for this type if one is known. + /// Returns nil if this is not a range type. + @usableFromInline + internal var boundType: PostgresDataType? { + switch self { + case .int4Range: return .int4 + case .int8Range: return .int8 default: return nil } } diff --git a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift index 2c57b605..fb2b62e3 100644 --- a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift @@ -85,6 +85,18 @@ extension UUID: PostgresArrayEncodable { public static var psqlArrayType: PostgresDataType { .uuidArray } } +extension Range: PostgresArrayDecodable where Bound: PostgresRangeArrayDecodable {} + +extension Range: PostgresArrayEncodable where Bound: PostgresRangeArrayEncodable { + public static var psqlArrayType: PostgresDataType { Bound.psqlRangeArrayType } +} + +extension ClosedRange: PostgresArrayDecodable where Bound: PostgresRangeArrayDecodable {} + +extension ClosedRange: PostgresArrayEncodable where Bound: PostgresRangeArrayEncodable { + public static var psqlArrayType: PostgresDataType { Bound.psqlRangeArrayType } +} + // MARK: Array conformances extension Array: PostgresEncodable where Element: PostgresArrayEncodable { diff --git a/Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift new file mode 100644 index 00000000..929330ef --- /dev/null +++ b/Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift @@ -0,0 +1,307 @@ +import NIOCore + +// MARK: Protocols + +/// A type that can be encoded into a Postgres range type where it is the bound type +public protocol PostgresRangeEncodable: PostgresNonThrowingEncodable { + static var psqlRangeType: PostgresDataType { get } +} + +/// A type that can be decoded into a Swift RangeExpression type from a Postgres range where it is the bound type +public protocol PostgresRangeDecodable: PostgresDecodable { + /// If a Postgres range type has a well-defined step, + /// Postgres automatically converts it to a canonical form. + /// Types such as `int4range` get converted to upper-bound-exclusive. + /// This method is needed when converting an upper bound to inclusive. + /// It should throw if the type lacks a well-defined step. + func upperBoundExclusiveToUpperBoundInclusive() throws -> Self + + /// Postgres does not store any bound values for empty ranges, + /// but Swift requires a value to initialize an empty Range. + static var valueForEmptyRange: Self { get } +} + +/// A type that can be encoded into a Postgres range array type where it is the bound type +public protocol PostgresRangeArrayEncodable: PostgresRangeEncodable { + static var psqlRangeArrayType: PostgresDataType { get } +} + +/// A type that can be decoded into a Swift RangeExpression array type from a Postgres range array where it is the bound type +public protocol PostgresRangeArrayDecodable: PostgresRangeDecodable {} + +// MARK: Bound conformances + +extension FixedWidthInteger where Self: PostgresRangeDecodable { + public func upperBoundExclusiveToUpperBoundInclusive() -> Self { + return self - 1 + } + + public static var valueForEmptyRange: Self { + return .zero + } +} + +extension Int32: PostgresRangeEncodable { + public static var psqlRangeType: PostgresDataType { return .int4Range } +} + +extension Int32: PostgresRangeDecodable {} + +extension Int32: PostgresRangeArrayEncodable { + public static var psqlRangeArrayType: PostgresDataType { return .int4RangeArray } +} + +extension Int32: PostgresRangeArrayDecodable {} + +extension Int64: PostgresRangeEncodable { + public static var psqlRangeType: PostgresDataType { return .int8Range } +} + +extension Int64: PostgresRangeDecodable {} + +extension Int64: PostgresRangeArrayEncodable { + public static var psqlRangeArrayType: PostgresDataType { return .int8RangeArray } +} + +extension Int64: PostgresRangeArrayDecodable {} + +// MARK: PostgresRange + +@usableFromInline +struct PostgresRange { + @usableFromInline let lowerBound: B? + @usableFromInline let upperBound: B? + @usableFromInline let isLowerBoundInclusive: Bool + @usableFromInline let isUpperBoundInclusive: Bool + + @inlinable + init( + lowerBound: B?, + upperBound: B?, + isLowerBoundInclusive: Bool, + isUpperBoundInclusive: Bool + ) { + self.lowerBound = lowerBound + self.upperBound = upperBound + self.isLowerBoundInclusive = isLowerBoundInclusive + self.isUpperBoundInclusive = isUpperBoundInclusive + } +} + +/// Used by Postgres to represent certain range properties +@usableFromInline +struct PostgresRangeFlag { + @usableFromInline static let isEmpty: UInt8 = 0x01 + @usableFromInline static let isLowerBoundInclusive: UInt8 = 0x02 + @usableFromInline static let isUpperBoundInclusive: UInt8 = 0x04 +} + +extension PostgresRange: PostgresDecodable where B: PostgresRangeDecodable { + @inlinable + init( + from byteBuffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + guard case .binary = format else { + throw PostgresDecodingError.Code.failure + } + + guard let boundType: PostgresDataType = type.boundType else { + throw PostgresDecodingError.Code.failure + } + + // flags byte contains certain properties of the range + guard let flags: UInt8 = byteBuffer.readInteger(as: UInt8.self) else { + throw PostgresDecodingError.Code.failure + } + + let isEmpty: Bool = flags & PostgresRangeFlag.isEmpty != 0 + if isEmpty { + self = PostgresRange( + lowerBound: B.valueForEmptyRange, + upperBound: B.valueForEmptyRange, + isLowerBoundInclusive: true, + isUpperBoundInclusive: false + ) + return + } + + guard let lowerBoundSize: Int32 = byteBuffer.readInteger(as: Int32.self), + Int(lowerBoundSize) == MemoryLayout.size, + var lowerBoundBytes: ByteBuffer = byteBuffer.readSlice(length: Int(lowerBoundSize)) + else { + throw PostgresDecodingError.Code.failure + } + + let lowerBound: B = try B(from: &lowerBoundBytes, type: boundType, format: format, context: context) + + guard let upperBoundSize = byteBuffer.readInteger(as: Int32.self), + Int(upperBoundSize) == MemoryLayout.size, + var upperBoundBytes: ByteBuffer = byteBuffer.readSlice(length: Int(upperBoundSize)) + else { + throw PostgresDecodingError.Code.failure + } + + let upperBound: B = try B(from: &upperBoundBytes, type: boundType, format: format, context: context) + + let isLowerBoundInclusive: Bool = flags & PostgresRangeFlag.isLowerBoundInclusive != 0 + let isUpperBoundInclusive: Bool = flags & PostgresRangeFlag.isUpperBoundInclusive != 0 + + self = PostgresRange( + lowerBound: lowerBound, + upperBound: upperBound, + isLowerBoundInclusive: isLowerBoundInclusive, + isUpperBoundInclusive: isUpperBoundInclusive + ) + + } +} + +extension PostgresRange: PostgresEncodable & PostgresNonThrowingEncodable where B: PostgresRangeEncodable { + @usableFromInline + static var psqlType: PostgresDataType { return B.psqlRangeType } + + @usableFromInline + static var psqlFormat: PostgresFormat { return .binary } + + @inlinable + func encode(into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext) { + // flags byte contains certain properties of the range + var flags: UInt8 = 0 + if self.isLowerBoundInclusive { + flags |= PostgresRangeFlag.isLowerBoundInclusive + } + if self.isUpperBoundInclusive { + flags |= PostgresRangeFlag.isUpperBoundInclusive + } + + let boundMemorySize = Int32(MemoryLayout.size) + + byteBuffer.writeInteger(flags) + if let lowerBound: B = self.lowerBound { + byteBuffer.writeInteger(boundMemorySize) + lowerBound.encode(into: &byteBuffer, context: context) + } + if let upperBound: B = self.upperBound { + byteBuffer.writeInteger(boundMemorySize) + upperBound.encode(into: &byteBuffer, context: context) + } + } +} + +extension PostgresRange where B: Comparable { + @inlinable + init(range: Range) { + self.lowerBound = range.lowerBound + self.upperBound = range.upperBound + self.isLowerBoundInclusive = true + self.isUpperBoundInclusive = false + } + + @inlinable + init(closedRange: ClosedRange) { + self.lowerBound = closedRange.lowerBound + self.upperBound = closedRange.upperBound + self.isLowerBoundInclusive = true + self.isUpperBoundInclusive = true + } +} + +// MARK: Range + +extension Range: PostgresEncodable where Bound: PostgresRangeEncodable { + public static var psqlType: PostgresDataType { return Bound.psqlRangeType } + public static var psqlFormat: PostgresFormat { return .binary } + + @inlinable + public func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + let postgresRange = PostgresRange(range: self) + postgresRange.encode(into: &byteBuffer, context: context) + } +} + +extension Range: PostgresNonThrowingEncodable where Bound: PostgresRangeEncodable {} + +extension Range: PostgresDecodable where Bound: PostgresRangeDecodable { + @inlinable + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + let postgresRange = try PostgresRange( + from: &buffer, + type: type, + format: format, + context: context + ) + + guard let lowerBound: Bound = postgresRange.lowerBound, + let upperBound: Bound = postgresRange.upperBound, + postgresRange.isLowerBoundInclusive, + !postgresRange.isUpperBoundInclusive + else { + throw PostgresDecodingError.Code.failure + } + + self = lowerBound..( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) { + let postgresRange = PostgresRange(closedRange: self) + postgresRange.encode(into: &byteBuffer, context: context) + } +} + +extension ClosedRange: PostgresNonThrowingEncodable where Bound: PostgresRangeEncodable {} + +extension ClosedRange: PostgresDecodable where Bound: PostgresRangeDecodable { + @inlinable + public init( + from buffer: inout ByteBuffer, + type: PostgresDataType, + format: PostgresFormat, + context: PostgresDecodingContext + ) throws { + let postgresRange = try PostgresRange( + from: &buffer, + type: type, + format: format, + context: context + ) + + guard let lowerBound: Bound = postgresRange.lowerBound, + var upperBound: Bound = postgresRange.upperBound, + postgresRange.isLowerBoundInclusive + else { + throw PostgresDecodingError.Code.failure + } + + if !postgresRange.isUpperBoundInclusive { + upperBound = try upperBound.upperBoundExclusiveToUpperBoundInclusive() + } + + if lowerBound > upperBound { + throw PostgresDecodingError.Code.failure + } + + self = lowerBound...upperBound + } +} diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index 348e6eb6..19c4e167 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -374,6 +374,120 @@ final class PostgresNIOTests: XCTestCase { XCTAssertEqual(UUID(uuidString: row?[data: "id"].string ?? ""), UUID(uuidString: "123E4567-E89B-12D3-A456-426655440000")) } + func testInt4Range() async throws { + let conn: PostgresConnection = try await PostgresConnection.test(on: eventLoop).get() + self.addTeardownBlock { + try await conn.close() + } + struct Model: Decodable { + let range: Range + } + let results1: PostgresQueryResult = try await conn.query(""" + SELECT + '[\(Int32.min), \(Int32.max))'::int4range AS range + """).get() + XCTAssertEqual(results1.count, 1) + var row = results1.first?.makeRandomAccess() + let expectedRange: Range = Int32.min...self, context: .default) + XCTAssertEqual(decodedRange, expectedRange) + + let results2 = try await conn.query(""" + SELECT + ARRAY[ + '[0, 1)'::int4range, + '[10, 11)'::int4range + ] AS ranges + """).get() + XCTAssertEqual(results2.count, 1) + row = results2.first?.makeRandomAccess() + let decodedRangeArray = try row?.decode(column: "ranges", as: [Range].self, context: .default) + let decodedClosedRangeArray = try row?.decode(column: "ranges", as: [ClosedRange].self, context: .default) + XCTAssertEqual(decodedRangeArray, [0..<1, 10..<11]) + XCTAssertEqual(decodedClosedRangeArray, [0...0, 10...10]) + } + + func testEmptyInt4Range() async throws { + let conn: PostgresConnection = try await PostgresConnection.test(on: eventLoop).get() + self.addTeardownBlock { + try await conn.close() + } + struct Model: Decodable { + let range: Range + } + let randomValue = Int32.random(in: Int32.min...Int32.max) + let results: PostgresQueryResult = try await conn.query(""" + SELECT + '[\(randomValue),\(randomValue))'::int4range AS range + """).get() + XCTAssertEqual(results.count, 1) + let row = results.first?.makeRandomAccess() + let expectedRange: Range = Int32.valueForEmptyRange...self, context: .default) + XCTAssertEqual(decodedRange, expectedRange) + + XCTAssertThrowsError( + try row?.decode(column: "range", as: ClosedRange.self, context: .default) + ) + } + + func testInt8Range() async throws { + let conn: PostgresConnection = try await PostgresConnection.test(on: eventLoop).get() + self.addTeardownBlock { + try await conn.close() + } + struct Model: Decodable { + let range: Range + } + let results1: PostgresQueryResult = try await conn.query(""" + SELECT + '[\(Int64.min), \(Int64.max))'::int8range AS range + """).get() + XCTAssertEqual(results1.count, 1) + var row = results1.first?.makeRandomAccess() + let expectedRange: Range = Int64.min...self, context: .default) + XCTAssertEqual(decodedRange, expectedRange) + + let results2: PostgresQueryResult = try await conn.query(""" + SELECT + ARRAY[ + '[0, 1)'::int8range, + '[10, 11)'::int8range + ] AS ranges + """).get() + XCTAssertEqual(results2.count, 1) + row = results2.first?.makeRandomAccess() + let decodedRangeArray = try row?.decode(column: "ranges", as: [Range].self, context: .default) + let decodedClosedRangeArray = try row?.decode(column: "ranges", as: [ClosedRange].self, context: .default) + XCTAssertEqual(decodedRangeArray, [0..<1, 10..<11]) + XCTAssertEqual(decodedClosedRangeArray, [0...0, 10...10]) + } + + func testEmptyInt8Range() async throws { + let conn: PostgresConnection = try await PostgresConnection.test(on: eventLoop).get() + self.addTeardownBlock { + try await conn.close() + } + struct Model: Decodable { + let range: Range + } + let randomValue = Int64.random(in: Int64.min...Int64.max) + let results: PostgresQueryResult = try await conn.query(""" + SELECT + '[\(randomValue),\(randomValue))'::int8range AS range + """).get() + XCTAssertEqual(results.count, 1) + let row = results.first?.makeRandomAccess() + let expectedRange: Range = Int64.valueForEmptyRange...self, context: .default) + XCTAssertEqual(decodedRange, expectedRange) + + XCTAssertThrowsError( + try row?.decode(column: "range", as: ClosedRange.self, context: .default) + ) + } + func testDates() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) @@ -771,6 +885,108 @@ final class PostgresNIOTests: XCTestCase { } } + func testInt4RangeSerialize() async throws { + let conn: PostgresConnection = try await PostgresConnection.test(on: eventLoop).get() + self.addTeardownBlock { + try await conn.close() + } + do { + let range: Range = Int32.min..? = try row?.decode(Range.self, context: .default) + XCTAssertEqual(range, decodedRange) + } + do { + let emptyRange: Range = Int32.min..? = try row?.decode(Range.self, context: .default) + let expectedRange: Range = Int32.valueForEmptyRange.. = Int32.min...(Int32.max - 1) + var binds = PostgresBindings() + binds.append(closedRange, context: .default) + let query = PostgresQuery( + unsafeSQL: "select $1::int4range as range", + binds: binds + ) + let rowSequence: PostgresRowSequence? = try await conn.query(query, logger: .psqlTest) + var rowIterator: PostgresRowSequence.AsyncIterator? = rowSequence?.makeAsyncIterator() + let row: PostgresRow? = try await rowIterator?.next() + let decodedClosedRange: ClosedRange? = try row?.decode(ClosedRange.self, context: .default) + XCTAssertEqual(closedRange, decodedClosedRange) + } + } + + func testInt8RangeSerialize() async throws { + let conn: PostgresConnection = try await PostgresConnection.test(on: eventLoop).get() + self.addTeardownBlock { + try await conn.close() + } + do { + let range: Range = Int64.min..? = try row?.decode(Range.self, context: .default) + XCTAssertEqual(range, decodedRange) + } + do { + let emptyRange: Range = Int64.min..? = try row?.decode(Range.self, context: .default) + let expectedRange: Range = Int64.valueForEmptyRange.. = Int64.min...(Int64.max - 1) + var binds = PostgresBindings() + binds.append(closedRange, context: .default) + let query = PostgresQuery( + unsafeSQL: "select $1::int8range as range", + binds: binds + ) + let rowSequence: PostgresRowSequence? = try await conn.query(query, logger: .psqlTest) + var rowIterator: PostgresRowSequence.AsyncIterator? = rowSequence?.makeAsyncIterator() + let row: PostgresRow? = try await rowIterator?.next() + let decodedClosedRange: ClosedRange? = try row?.decode(ClosedRange.self, context: .default) + XCTAssertEqual(closedRange, decodedClosedRange) + } + } + func testRemoteTLSServer() { // postgres://uymgphwj:7_tHbREdRwkqAdu4KoIS7hQnNxr8J1LA@elmer.db.elephantsql.com:5432/uymgphwj var conn: PostgresConnection? diff --git a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift index 0a1da7c6..79d47c30 100644 --- a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift @@ -55,6 +55,22 @@ class Array_PSQLCodableTests: XCTestCase { XCTAssertEqual(UUID.psqlArrayType, .uuidArray) XCTAssertEqual(UUID.psqlType, .uuid) XCTAssertEqual([UUID].psqlType, .uuidArray) + + XCTAssertEqual(Range.psqlArrayType, .int4RangeArray) + XCTAssertEqual(Range.psqlType, .int4Range) + XCTAssertEqual([Range].psqlType, .int4RangeArray) + + XCTAssertEqual(ClosedRange.psqlArrayType, .int4RangeArray) + XCTAssertEqual(ClosedRange.psqlType, .int4Range) + XCTAssertEqual([ClosedRange].psqlType, .int4RangeArray) + + XCTAssertEqual(Range.psqlArrayType, .int8RangeArray) + XCTAssertEqual(Range.psqlType, .int8Range) + XCTAssertEqual([Range].psqlType, .int8RangeArray) + + XCTAssertEqual(ClosedRange.psqlArrayType, .int8RangeArray) + XCTAssertEqual(ClosedRange.psqlType, .int8Range) + XCTAssertEqual([ClosedRange].psqlType, .int8RangeArray) } func testStringArrayRoundTrip() { diff --git a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift index b08c2de2..769bde4b 100644 --- a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift @@ -68,7 +68,7 @@ class Date_PSQLCodableTests: XCTestCase { XCTAssertNotNil(lastDate) } - func testDecodeDateFailsWithToMuchData() { + func testDecodeDateFailsWithTooMuchData() { var buffer = ByteBuffer() buffer.writeInteger(Int64(0)) diff --git a/Tests/PostgresNIOTests/New/Data/Range+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Range+PSQLCodableTests.swift new file mode 100644 index 00000000..a040c3f4 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Data/Range+PSQLCodableTests.swift @@ -0,0 +1,105 @@ +import XCTest +import NIOCore +@testable import PostgresNIO + +class Range_PSQLCodableTests: XCTestCase { + func testInt32RangeRoundTrip() { + let lowerBound = Int32.min + let upperBound = Int32.max + let value: Range = lowerBound...psqlType, .int4Range) + XCTAssertEqual(buffer.readableBytes, 17) + XCTAssertEqual(buffer.getInteger(at: 0, as: UInt8.self), 2) + XCTAssertEqual(buffer.getInteger(at: 1, as: UInt32.self), 4) + XCTAssertEqual(buffer.getInteger(at: 5, as: Int32.self), lowerBound) + XCTAssertEqual(buffer.getInteger(at: 9, as: UInt32.self), 4) + XCTAssertEqual(buffer.getInteger(at: 13, as: Int32.self), upperBound) + + var result: Range? + XCTAssertNoThrow(result = try Range(from: &buffer, type: .int4Range, format: .binary, context: .default)) + XCTAssertEqual(value, result) + } + + func testInt32ClosedRangeRoundTrip() { + let lowerBound = Int32.min + let upperBound = Int32.max - 1 + let value: ClosedRange = lowerBound...upperBound + + var buffer = ByteBuffer() + value.encode(into: &buffer, context: .default) + XCTAssertEqual(ClosedRange.psqlType, .int4Range) + XCTAssertEqual(buffer.readableBytes, 17) + XCTAssertEqual(buffer.getInteger(at: 0, as: UInt8.self), 6) + XCTAssertEqual(buffer.getInteger(at: 1, as: UInt32.self), 4) + XCTAssertEqual(buffer.getInteger(at: 5, as: Int32.self), lowerBound) + XCTAssertEqual(buffer.getInteger(at: 9, as: UInt32.self), 4) + XCTAssertEqual(buffer.getInteger(at: 13, as: Int32.self), upperBound) + + var result: ClosedRange? + XCTAssertNoThrow(result = try ClosedRange(from: &buffer, type: .int4Range, format: .binary, context: .default)) + XCTAssertEqual(value, result) + } + + func testInt64RangeRoundTrip() { + let lowerBound = Int64.min + let upperBound = Int64.max + let value: Range = lowerBound...psqlType, .int8Range) + XCTAssertEqual(buffer.readableBytes, 25) + XCTAssertEqual(buffer.getInteger(at: 0, as: UInt8.self), 2) + XCTAssertEqual(buffer.getInteger(at: 1, as: UInt32.self), 8) + XCTAssertEqual(buffer.getInteger(at: 5, as: Int64.self), lowerBound) + XCTAssertEqual(buffer.getInteger(at: 13, as: UInt32.self), 8) + XCTAssertEqual(buffer.getInteger(at: 17, as: Int64.self), upperBound) + + var result: Range? + XCTAssertNoThrow(result = try Range(from: &buffer, type: .int8Range, format: .binary, context: .default)) + XCTAssertEqual(value, result) + } + + func testInt64ClosedRangeRoundTrip() { + let lowerBound = Int64.min + let upperBound = Int64.max - 1 + let value: ClosedRange = lowerBound...upperBound + + var buffer = ByteBuffer() + value.encode(into: &buffer, context: .default) + XCTAssertEqual(ClosedRange.psqlType, .int8Range) + XCTAssertEqual(buffer.readableBytes, 25) + XCTAssertEqual(buffer.getInteger(at: 0, as: UInt8.self), 6) + XCTAssertEqual(buffer.getInteger(at: 1, as: UInt32.self), 8) + XCTAssertEqual(buffer.getInteger(at: 5, as: Int64.self), lowerBound) + XCTAssertEqual(buffer.getInteger(at: 13, as: UInt32.self), 8) + XCTAssertEqual(buffer.getInteger(at: 17, as: Int64.self), upperBound) + + var result: ClosedRange? + XCTAssertNoThrow(result = try ClosedRange(from: &buffer, type: .int8Range, format: .binary, context: .default)) + XCTAssertEqual(value, result) + } + + func testInt64RangeDecodeFailureInvalidLength() { + var buffer = ByteBuffer() + buffer.writeInteger(0) + buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) + buffer.writeInteger(Int64.random(in: Int64.min...Int64.max)) + + XCTAssertThrowsError(try Range(from: &buffer, type: .int8Range, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } + + func testInt64RangeDecodeFailureWrongDataType() { + var buffer = ByteBuffer() + (Int64.min...Int64.max).encode(into: &buffer, context: .default) + + XCTAssertThrowsError(try Range(from: &buffer, type: .int8, format: .binary, context: .default)) { + XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + } + } +} From b3e1881ab0bdd8323ee927bc4bdf116285154972 Mon Sep 17 00:00:00 2001 From: Marius Seufzer <44228394+marius-se@users.noreply.github.com> Date: Wed, 10 May 2023 09:53:05 +0200 Subject: [PATCH 155/292] Add Postgres 15 to docker compose (#366) --- docker-compose.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docker-compose.yml b/docker-compose.yml index 600bdc99..68797651 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,6 +10,9 @@ x-shared-config: &shared_config - 5432:5432 services: + psql-15: + image: postgres:15 + <<: *shared_config psql-14: image: postgres:14 <<: *shared_config From 62080bf919db03103137ce573ce110c19d14fe56 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Thu, 11 May 2023 02:22:19 -0500 Subject: [PATCH 156/292] Decode `.bpchar` as `String` (#368) * `.bpchar` is "blank-padded char", the low-level Postgres name for `character(N)` (the auto-padded form of `character varying`). `String`'s `PostgresCodable` conformance should thus recognize it. * Add .bpchar test and tell the invalid encoding test that `bpchar` is a string, for pity's sake. --- Sources/PostgresNIO/New/Data/String+PostgresCodable.swift | 1 + Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift index f8e93e94..41091ab3 100644 --- a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift @@ -30,6 +30,7 @@ extension String: PostgresDecodable { ) throws { switch (format, type) { case (_, .varchar), + (_, .bpchar), (_, .text), (_, .name): // we can force unwrap here, since this method only fails if there are not enough diff --git a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift index 614749c1..6ff35130 100644 --- a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift @@ -20,7 +20,7 @@ class String_PSQLCodableTests: XCTestCase { buffer.writeString(expected) let dataTypes: [PostgresDataType] = [ - .text, .varchar, .name + .text, .varchar, .name, .bpchar ] for dataType in dataTypes { @@ -33,7 +33,7 @@ class String_PSQLCodableTests: XCTestCase { func testDecodeFailureFromInvalidType() { let buffer = ByteBuffer() - let dataTypes: [PostgresDataType] = [.bool, .float4Array, .float8Array, .bpchar] + let dataTypes: [PostgresDataType] = [.bool, .float4Array, .float8Array] for dataType in dataTypes { var loopBuffer = buffer From 2bfdd553305972405ec51499e694262a6dc8dbff Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 11 May 2023 10:28:33 +0200 Subject: [PATCH 157/292] Rename generic type from B to Bound in PostgresRange (#367) --- .../New/Data/Range+PostgresCodable.swift | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift index 929330ef..e5a3e60e 100644 --- a/Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift @@ -68,16 +68,16 @@ extension Int64: PostgresRangeArrayDecodable {} // MARK: PostgresRange @usableFromInline -struct PostgresRange { - @usableFromInline let lowerBound: B? - @usableFromInline let upperBound: B? +struct PostgresRange { + @usableFromInline let lowerBound: Bound? + @usableFromInline let upperBound: Bound? @usableFromInline let isLowerBoundInclusive: Bool @usableFromInline let isUpperBoundInclusive: Bool @inlinable init( - lowerBound: B?, - upperBound: B?, + lowerBound: Bound?, + upperBound: Bound?, isLowerBoundInclusive: Bool, isUpperBoundInclusive: Bool ) { @@ -96,7 +96,7 @@ struct PostgresRangeFlag { @usableFromInline static let isUpperBoundInclusive: UInt8 = 0x04 } -extension PostgresRange: PostgresDecodable where B: PostgresRangeDecodable { +extension PostgresRange: PostgresDecodable where Bound: PostgresRangeDecodable { @inlinable init( from byteBuffer: inout ByteBuffer, @@ -119,9 +119,9 @@ extension PostgresRange: PostgresDecodable where B: PostgresRangeDecodable { let isEmpty: Bool = flags & PostgresRangeFlag.isEmpty != 0 if isEmpty { - self = PostgresRange( - lowerBound: B.valueForEmptyRange, - upperBound: B.valueForEmptyRange, + self = PostgresRange( + lowerBound: Bound.valueForEmptyRange, + upperBound: Bound.valueForEmptyRange, isLowerBoundInclusive: true, isUpperBoundInclusive: false ) @@ -129,27 +129,27 @@ extension PostgresRange: PostgresDecodable where B: PostgresRangeDecodable { } guard let lowerBoundSize: Int32 = byteBuffer.readInteger(as: Int32.self), - Int(lowerBoundSize) == MemoryLayout.size, + Int(lowerBoundSize) == MemoryLayout.size, var lowerBoundBytes: ByteBuffer = byteBuffer.readSlice(length: Int(lowerBoundSize)) else { throw PostgresDecodingError.Code.failure } - let lowerBound: B = try B(from: &lowerBoundBytes, type: boundType, format: format, context: context) + let lowerBound = try Bound(from: &lowerBoundBytes, type: boundType, format: format, context: context) guard let upperBoundSize = byteBuffer.readInteger(as: Int32.self), - Int(upperBoundSize) == MemoryLayout.size, + Int(upperBoundSize) == MemoryLayout.size, var upperBoundBytes: ByteBuffer = byteBuffer.readSlice(length: Int(upperBoundSize)) else { throw PostgresDecodingError.Code.failure } - let upperBound: B = try B(from: &upperBoundBytes, type: boundType, format: format, context: context) + let upperBound = try Bound(from: &upperBoundBytes, type: boundType, format: format, context: context) let isLowerBoundInclusive: Bool = flags & PostgresRangeFlag.isLowerBoundInclusive != 0 let isUpperBoundInclusive: Bool = flags & PostgresRangeFlag.isUpperBoundInclusive != 0 - self = PostgresRange( + self = PostgresRange( lowerBound: lowerBound, upperBound: upperBound, isLowerBoundInclusive: isLowerBoundInclusive, @@ -159,10 +159,10 @@ extension PostgresRange: PostgresDecodable where B: PostgresRangeDecodable { } } -extension PostgresRange: PostgresEncodable & PostgresNonThrowingEncodable where B: PostgresRangeEncodable { +extension PostgresRange: PostgresEncodable & PostgresNonThrowingEncodable where Bound: PostgresRangeEncodable { @usableFromInline - static var psqlType: PostgresDataType { return B.psqlRangeType } - + static var psqlType: PostgresDataType { return Bound.psqlRangeType } + @usableFromInline static var psqlFormat: PostgresFormat { return .binary } @@ -177,23 +177,23 @@ extension PostgresRange: PostgresEncodable & PostgresNonThrowingEncodable where flags |= PostgresRangeFlag.isUpperBoundInclusive } - let boundMemorySize = Int32(MemoryLayout.size) + let boundMemorySize = Int32(MemoryLayout.size) byteBuffer.writeInteger(flags) - if let lowerBound: B = self.lowerBound { + if let lowerBound = self.lowerBound { byteBuffer.writeInteger(boundMemorySize) lowerBound.encode(into: &byteBuffer, context: context) } - if let upperBound: B = self.upperBound { + if let upperBound = self.upperBound { byteBuffer.writeInteger(boundMemorySize) upperBound.encode(into: &byteBuffer, context: context) } } } -extension PostgresRange where B: Comparable { +extension PostgresRange where Bound: Comparable { @inlinable - init(range: Range) { + init(range: Range) { self.lowerBound = range.lowerBound self.upperBound = range.upperBound self.isLowerBoundInclusive = true @@ -201,7 +201,7 @@ extension PostgresRange where B: Comparable { } @inlinable - init(closedRange: ClosedRange) { + init(closedRange: ClosedRange) { self.lowerBound = closedRange.lowerBound self.upperBound = closedRange.upperBound self.isLowerBoundInclusive = true From fcb2e66880aebf92170b299d49afe96f40341740 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Thu, 11 May 2023 06:59:32 -0500 Subject: [PATCH 158/292] Add Swift version info to CI output (#369) --- .github/workflows/test.yml | 41 +++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index be74e3b8..25374cf3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,6 +9,8 @@ on: pull_request: branches: - "*" +env: + LOG_LEVEL: info jobs: linux-unit: @@ -23,9 +25,15 @@ jobs: - swiftlang/swift:nightly-main-jammy container: ${{ matrix.container }} runs-on: ubuntu-latest - env: - LOG_LEVEL: debug steps: + - name: Display OS and Swift versions + shell: bash + run: | + if [[ '${{ contains(matrix.container, 'nightly') }}' == 'true' ]]; then + SWIFT_PLATFORM="$(source /etc/os-release && echo "${ID}${VERSION_ID}")" SWIFT_VERSION="$(cat /.swift_tag)" + printf 'SWIFT_PLATFORM=%s\nSWIFT_VERSION=%s\n' "${SWIFT_PLATFORM}" "${SWIFT_VERSION}" >>"${GITHUB_ENV}" + fi + printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" && swift --version - name: Check out package uses: actions/checkout@v3 - name: Run unit tests with code coverage and Thread Sanitizer @@ -57,7 +65,6 @@ jobs: volumes: [ 'pgrunshare:/var/run/postgresql' ] runs-on: ubuntu-latest env: - LOG_LEVEL: debug # Unfortunately, fluent-postgres-driver details leak through here POSTGRES_DB: 'test_database' POSTGRES_DB_A: 'test_database' @@ -93,6 +100,9 @@ jobs: POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} POSTGRES_INITDB_ARGS: --auth-host=${{ matrix.dbauth }} steps: + - name: Display OS and Swift versions + run: | + printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" && swift --version - name: Check out package uses: actions/checkout@v3 with: { path: 'postgres-nio' } @@ -126,9 +136,8 @@ jobs: - scram-sha-256 xcode: - latest-stable - runs-on: macos-12 + runs-on: macos-13 env: - LOG_LEVEL: debug POSTGRES_HOSTNAME: 127.0.0.1 POSTGRES_USER: 'test_username' POSTGRES_PASSWORD: 'test_password' @@ -143,8 +152,8 @@ jobs: - name: Install Postgres, setup DB and auth, and wait for server start run: | export PATH="$(brew --prefix)/opt/${{ matrix.dbimage }}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test - (brew unlink postgresql || true) && brew install ${{ matrix.dbimage }} && brew link --force ${{ matrix.dbimage }} - initdb --locale=C --auth-host ${{ matrix.dbauth }} -U $POSTGRES_USER --pwfile=<(echo $POSTGRES_PASSWORD) + (brew unlink postgresql || true) && brew install '${{ matrix.dbimage }}' && brew link --force '${{ matrix.dbimage }}' + initdb --locale=C --auth-host '${{ matrix.dbauth }}' -U "${POSTGRES_USER}" --pwfile=<(echo "${POSTGRES_PASSWORD}") pg_ctl start --wait timeout-minutes: 2 - name: Checkout code @@ -157,12 +166,12 @@ jobs: runs-on: ubuntu-latest container: swift:5.8-jammy steps: - - name: Checkout - uses: actions/checkout@v3 - with: - fetch-depth: 0 - # https://github.com/actions/checkout/issues/766 - - name: Mark the workspace as safe - run: git config --global --add safe.directory ${GITHUB_WORKSPACE} - - name: API breaking changes - run: swift package diagnose-api-breaking-changes origin/main + - name: Checkout + uses: actions/checkout@v3 + with: + fetch-depth: 0 + # https://github.com/actions/checkout/issues/766 + - name: Mark the workspace as safe + run: git config --global --add safe.directory "${GITHUB_WORKSPACE}" + - name: API breaking changes + run: swift package diagnose-api-breaking-changes origin/main From f87709218c4e444cb9593dd4dd4a3a0420950ee5 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Mon, 15 May 2023 07:26:42 -0500 Subject: [PATCH 159/292] No projectboard workflow for this package --- .github/workflows/projectboard.yml | 11 ----------- 1 file changed, 11 deletions(-) delete mode 100644 .github/workflows/projectboard.yml diff --git a/.github/workflows/projectboard.yml b/.github/workflows/projectboard.yml deleted file mode 100644 index a0e6d988..00000000 --- a/.github/workflows/projectboard.yml +++ /dev/null @@ -1,11 +0,0 @@ -name: issue-to-project-board-workflow -on: - # Trigger when an issue gets labeled or deleted - issues: - types: [reopened, closed, labeled, unlabeled, assigned, unassigned] - -jobs: - update_project_boards: - name: Update project boards - uses: vapor/ci/.github/workflows/update-project-boards-for-issue.yml@reusable-workflows - secrets: inherit From f1744c8e900b6ad859ee643a494ad481999cc779 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Fri, 2 Jun 2023 09:57:35 -0500 Subject: [PATCH 160/292] Add many, many missing types to PostgresDataType (#371) * Add lots of missing PSQL type codes and element/array type mappings, including geometry types, the reg* types, network address types, many missing array types, bit and varbit, the pseudo-types, text search types, more range types, jsonpath, and the multirange types. * More clearly document where the list of types etc. came from. --- .../PostgresNIO/Data/PostgresDataType.swift | 521 +++++++++++++++++- 1 file changed, 515 insertions(+), 6 deletions(-) diff --git a/Sources/PostgresNIO/Data/PostgresDataType.swift b/Sources/PostgresNIO/Data/PostgresDataType.swift index d57f2529..ede60f47 100644 --- a/Sources/PostgresNIO/Data/PostgresDataType.swift +++ b/Sources/PostgresNIO/Data/PostgresDataType.swift @@ -24,8 +24,12 @@ extension PostgresFormat: Codable {} @available(*, deprecated, renamed: "PostgresFormat") public typealias PostgresFormatCode = PostgresFormat -/// The data type's raw object ID. -/// Use `select * from pg_type where oid = ;` to lookup more information. +/// Data types and their raw OIDs. +/// +/// Use `select * from pg_type where oid = ` to look up more information for a given type. +/// +/// This list was generated by running `select oid, typname from pg_type where oid < 10000 order by oid` +/// and manually trimming Postgres-internal types. public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStringConvertible { /// `0` public static let null = PostgresDataType(0) @@ -41,6 +45,8 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri public static let int8 = PostgresDataType(20) /// `21` public static let int2 = PostgresDataType(21) + /// `22` + public static let int2vector = PostgresDataType(22) /// `23` public static let int4 = PostgresDataType(23) /// `24` @@ -49,18 +55,75 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri public static let text = PostgresDataType(25) /// `26` public static let oid = PostgresDataType(26) + /// `27` + public static let tid = PostgresDataType(27) + /// `28` + public static let xid = PostgresDataType(28) + /// `29` + public static let cid = PostgresDataType(29) + /// `30` + public static let oidvector = PostgresDataType(30) + /// `32` + public static let pgDDLCommand = PostgresDataType(32) /// `114` public static let json = PostgresDataType(114) + /// `142` + public static let xml = PostgresDataType(142) + /// `143` + public static let xmlArray = PostgresDataType(143) /// `194` pg_node_tree + @available(*, deprecated, message: "This is internal to Postgres and should not be used.") public static let pgNodeTree = PostgresDataType(194) + /// `199` + public static let jsonArray = PostgresDataType(199) + /// `269` + public static let tableAMHandler = PostgresDataType(269) + /// `271` + public static let xid8Array = PostgresDataType(271) + /// `325` + public static let indexAMHandler = PostgresDataType(325) /// `600` public static let point = PostgresDataType(600) + /// `601` + public static let lseg = PostgresDataType(601) + /// `602` + public static let path = PostgresDataType(602) + /// `603` + public static let box = PostgresDataType(603) + /// `604` + public static let polygon = PostgresDataType(604) + /// `628` + public static let line = PostgresDataType(628) + /// `629` + public static let lineArray = PostgresDataType(629) + /// `650` + public static let cidr = PostgresDataType(650) + /// `651` + public static let cidrArray = PostgresDataType(651) /// `700` public static let float4 = PostgresDataType(700) /// `701` public static let float8 = PostgresDataType(701) + /// `705` + public static let unknown = PostgresDataType(705) + /// `718` + public static let circle = PostgresDataType(718) + /// `719` + public static let circleArray = PostgresDataType(719) + /// `774` + public static let macaddr8 = PostgresDataType(774) + /// `775` + public static let macaddr8Aray = PostgresDataType(775) /// `790` public static let money = PostgresDataType(790) + /// `791` + @available(*, deprecated, renamed: "moneyArray") + public static let _money = PostgresDataType(791) + public static let moneyArray = PostgresDataType(791) + /// `829` + public static let macaddr = PostgresDataType(829) + /// `869` + public static let inet = PostgresDataType(869) /// `1000` _bool public static let boolArray = PostgresDataType(1000) /// `1001` _bytea @@ -71,22 +134,52 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri public static let nameArray = PostgresDataType(1003) /// `1005` _int2 public static let int2Array = PostgresDataType(1005) + /// `1006` + public static let int2vectorArray = PostgresDataType(1006) /// `1007` _int4 public static let int4Array = PostgresDataType(1007) + /// `1008` + public static let regprocArray = PostgresDataType(1008) /// `1009` _text public static let textArray = PostgresDataType(1009) + /// `1010` + public static let tidArray = PostgresDataType(1010) + /// `1011` + public static let xidArray = PostgresDataType(1011) + /// `1012` + public static let cidArray = PostgresDataType(1012) + /// `1013` + public static let oidvectorArray = PostgresDataType(1013) + /// `1014` + public static let bpcharArray = PostgresDataType(1014) /// `1015` _varchar public static let varcharArray = PostgresDataType(1015) /// `1016` _int8 public static let int8Array = PostgresDataType(1016) /// `1017` _point public static let pointArray = PostgresDataType(1017) + /// `1018` + public static let lsegArray = PostgresDataType(1018) + /// `1019` + public static let pathArray = PostgresDataType(1019) + /// `1020` + public static let boxArray = PostgresDataType(1020) /// `1021` _float4 public static let float4Array = PostgresDataType(1021) /// `1022` _float8 public static let float8Array = PostgresDataType(1022) + /// `1027` + public static let polygonArray = PostgresDataType(1027) + /// `1028` + public static let oidArray = PostgresDataType(1018) + /// `1033` + public static let aclitem = PostgresDataType(1033) /// `1034` _aclitem public static let aclitemArray = PostgresDataType(1034) + /// `1040` + public static let macaddrArray = PostgresDataType(1040) + /// `1041` + public static let inetArray = PostgresDataType(1041) /// `1042` public static let bpchar = PostgresDataType(1042) /// `1043` @@ -101,28 +194,196 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri public static let timestampArray = PostgresDataType(1115) /// `1184` public static let timestamptz = PostgresDataType(1184) + /// `1185` + public static let timestamptzArray = PostgresDataType(1185) + /// `1186` + public static let interval = PostgresDataType(1186) + /// `1187` + public static let intervalArray = PostgresDataType(1187) + /// `1231` + public static let numericArray = PostgresDataType(1231) + /// `1263` + public static let cstringArray = PostgresDataType(1263) /// `1266` public static let timetz = PostgresDataType(1266) + /// `1270` + public static let timetzArray = PostgresDataType(1270) + /// `1560` + public static let bit = PostgresDataType(1560) + /// `1561` + public static let bitArray = PostgresDataType(1561) + /// `1562` + public static let varbit = PostgresDataType(1562) + /// `1563` + public static let varbitArray = PostgresDataType(1563) /// `1700` public static let numeric = PostgresDataType(1700) + /// `1790` + public static let refcursor = PostgresDataType(1790) + /// `2201` + public static let refcursorArray = PostgresDataType(2201) + /// `2202` + public static let regprocedure = PostgresDataType(2202) + /// `2203` + public static let regoper = PostgresDataType(2203) + /// `2204` + public static let regoperator = PostgresDataType(2204) + /// `2205` + public static let regclass = PostgresDataType(2205) + /// `2206` + public static let regtype = PostgresDataType(2206) + /// `2207` + public static let regprocedureArray = PostgresDataType(2207) + /// `2208` + public static let regoperArray = PostgresDataType(2208) + /// `2209` + public static let regoperatorArray = PostgresDataType(2209) + /// `2210` + public static let regclassArray = PostgresDataType(2210) + /// `2211` + public static let regtypeArray = PostgresDataType(2211) + /// `2249` + public static let record = PostgresDataType(2249) + /// `2275` + public static let cstring = PostgresDataType(2275) + /// `2276` + public static let any = PostgresDataType(2276) + /// `2277` + public static let anyarray = PostgresDataType(2277) /// `2278` public static let void = PostgresDataType(2278) + /// `2279` + public static let trigger = PostgresDataType(2279) + /// `2280` + public static let languageHandler = PostgresDataType(2280) + /// `2281` + public static let `internal` = PostgresDataType(2281) + /// `2283` + public static let anyelement = PostgresDataType(2283) + /// `2287` + public static let recordArray = PostgresDataType(2287) + /// `2776` + public static let anynonarray = PostgresDataType(2776) /// `2950` public static let uuid = PostgresDataType(2950) /// `2951` _uuid public static let uuidArray = PostgresDataType(2951) + /// `3115` + public static let fdwHandler = PostgresDataType(3115) + /// `3220` + public static let pgLSN = PostgresDataType(3220) + /// `3221` + public static let pgLSNArray = PostgresDataType(3221) + /// `3310` + public static let tsmHandler = PostgresDataType(3310) + /// `3500` + public static let anyenum = PostgresDataType(3500) + /// `3614` + public static let tsvector = PostgresDataType(3614) + /// `3615` + public static let tsquery = PostgresDataType(3615) + /// `3642` + public static let gtsvector = PostgresDataType(3642) + /// `3643` + public static let tsvectorArray = PostgresDataType(3643) + /// `3644` + public static let gtsvectorArray = PostgresDataType(3644) + /// `3645` + public static let tsqueryArray = PostgresDataType(3645) + /// `3734` + public static let regconfig = PostgresDataType(3734) + /// `3735` + public static let regconfigArray = PostgresDataType(3735) + /// `3769` + public static let regdictionary = PostgresDataType(3769) + /// `3770` + public static let regdictionaryArray = PostgresDataType(3770) /// `3802` public static let jsonb = PostgresDataType(3802) /// `3807` _jsonb public static let jsonbArray = PostgresDataType(3807) + /// `3831` + public static let anyrange = PostgresDataType(3831) + /// `3838` + public static let eventTrigger = PostgresDataType(3838) /// `3904` public static let int4Range = PostgresDataType(3904) /// `3905` _int4range public static let int4RangeArray = PostgresDataType(3905) + /// `3906` + public static let numrange = PostgresDataType(3906) + /// `3907` + public static let numrangeArray = PostgresDataType(3907) + /// `3908` + public static let tsrange = PostgresDataType(3908) + /// `3909` + public static let tsrangeArray = PostgresDataType(3909) + /// `3910` + public static let tstzrange = PostgresDataType(3910) + /// `3911` + public static let tstzrangeArray = PostgresDataType(3911) + /// `3912` + public static let daterange = PostgresDataType(3912) + /// `3913` + public static let daterangeArray = PostgresDataType(3913) /// `3926` public static let int8Range = PostgresDataType(3926) /// `3927` _int8range public static let int8RangeArray = PostgresDataType(3927) + /// `4072` + public static let jsonpath = PostgresDataType(4072) + /// `4073` + public static let jsonpathArray = PostgresDataType(4073) + /// `4089` + public static let regnamespace = PostgresDataType(4089) + /// `4090` + public static let regnamespaceArray = PostgresDataType(4090) + /// `4096` + public static let regrole = PostgresDataType(4096) + /// `4097` + public static let regroleArray = PostgresDataType(4097) + /// `4191` + public static let regcollation = PostgresDataType(4191) + /// `4192` + public static let regcollationArray = PostgresDataType(4192) + /// `4451` + public static let int4multirange = PostgresDataType(4451) + /// `4532` + public static let nummultirange = PostgresDataType(4532) + /// `4533` + public static let tsmultirange = PostgresDataType(4533) + /// `4534` + public static let tstzmultirange = PostgresDataType(4534) + /// `4535` + public static let datemultirange = PostgresDataType(4535) + /// `4536` + public static let int8multirange = PostgresDataType(4536) + /// `4537` + public static let anymultirange = PostgresDataType(4537) + /// `4538` + public static let anycompatiblemultirange = PostgresDataType(4538) + /// `5069` + public static let xid8 = PostgresDataType(5069) + /// `5077` + public static let anycompatible = PostgresDataType(5077) + /// `5078` + public static let anycompatiblearray = PostgresDataType(5078) + /// `5079` + public static let anycompatiblenonarray = PostgresDataType(5079) + /// `5080` + public static let anycompatiblerange = PostgresDataType(5080) + /// `6150` + public static let int4multirangeArray = PostgresDataType(6150) + /// `6151` + public static let nummultirangeArray = PostgresDataType(6151) + /// `6152` + public static let tsmultirangeArray = PostgresDataType(6152) + /// `6153` + public static let tstzmultirangeArray = PostgresDataType(6153) + /// `6155` + public static let datemultirangeArray = PostgresDataType(6155) + /// `6157` + public static let int8multirangeArray = PostgresDataType(6157) /// The raw data type code recognized by PostgreSQL. public var rawValue: UInt32 @@ -144,61 +405,246 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri /// Returns the known SQL name, if one exists. /// Note: This only supports a limited subset of all PSQL types and is meant for convenience only. + /// This list was manually generated. public var knownSQLName: String? { switch self { + case .null: return "NULL" case .bool: return "BOOLEAN" case .bytea: return "BYTEA" case .char: return "CHAR" case .name: return "NAME" case .int8: return "BIGINT" case .int2: return "SMALLINT" + case .int2vector: return "INT2VECTOR" case .int4: return "INTEGER" case .regproc: return "REGPROC" case .text: return "TEXT" case .oid: return "OID" + case .tid: return "TID" + case .xid: return "XID" + case .cid: return "CID" + case .oidvector: return "OIDVECTOR" + case .pgDDLCommand: return "PG_DDL_COMMAND" case .json: return "JSON" - case .pgNodeTree: return "PGNODETREE" + case .xml: return "XML" + case .xmlArray: return "XML[]" + case .jsonArray: return "JSON[]" + case .tableAMHandler: return "TABLE_AM_HANDLER" + case .xid8Array: return "XID8[]" + case .indexAMHandler: return "INDEX_AM_HANDLER" case .point: return "POINT" + case .lseg: return "LSEG" + case .path: return "PATH" + case .box: return "BOX" + case .polygon: return "POLYGON" + case .line: return "LINE" + case .lineArray: return "LINE[]" + case .cidr: return "CIDR" + case .cidrArray: return "CIDR[]" case .float4: return "REAL" case .float8: return "DOUBLE PRECISION" + case .circle: return "CIRCLE" + case .circleArray: return "CIRCLE[]" + case .macaddr8: return "MACADDR8" + case .macaddr8Aray: return "MACADDR8[]" case .money: return "MONEY" + case .moneyArray: return "MONEY[]" + case .macaddr: return "MACADDR" + case .inet: return "INET" case .boolArray: return "BOOLEAN[]" case .byteaArray: return "BYTEA[]" case .charArray: return "CHAR[]" case .nameArray: return "NAME[]" case .int2Array: return "SMALLINT[]" + case .int2vectorArray: return "INT2VECTOR[]" case .int4Array: return "INTEGER[]" + case .regprocArray: return "REGPROC[]" case .textArray: return "TEXT[]" + case .tidArray: return "TID[]" + case .xidArray: return "XID[]" + case .cidArray: return "CID[]" + case .oidvectorArray: return "OIDVECTOR[]" + case .bpcharArray: return "CHARACTER[]" case .varcharArray: return "VARCHAR[]" case .int8Array: return "BIGINT[]" case .pointArray: return "POINT[]" + case .lsegArray: return "LSEG[]" + case .pathArray: return "PATH[]" + case .boxArray: return "BOX[]" case .float4Array: return "REAL[]" case .float8Array: return "DOUBLE PRECISION[]" + case .polygonArray: return "POLYGON[]" + case .oidArray: return "OID[]" + case .aclitem: return "ACLITEM" case .aclitemArray: return "ACLITEM[]" - case .bpchar: return "BPCHAR" + case .macaddrArray: return "MACADDR[]" + case .inetArray: return "INET[]" + case .bpchar: return "CHARACTER" case .varchar: return "VARCHAR" case .date: return "DATE" case .time: return "TIME" case .timestamp: return "TIMESTAMP" - case .timestamptz: return "TIMESTAMPTZ" case .timestampArray: return "TIMESTAMP[]" + case .timestamptz: return "TIMESTAMPTZ" + case .timestamptzArray: return "TIMESTAMPTZ[]" + case .interval: return "INTERVAL" + case .intervalArray: return "INTERVAL[]" + case .numericArray: return "NUMERIC[]" + case .cstringArray: return "CSTRING[]" + case .timetz: return "TIMETZ" + case .timetzArray: return "TIMETZ[]" + case .bit: return "BIT" + case .bitArray: return "BIT[]" + case .varbit: return "VARBIT" + case .varbitArray: return "VARBIT[]" case .numeric: return "NUMERIC" + case .refcursor: return "REFCURSOR" + case .refcursorArray: return "REFCURSOR[]" + case .regprocedure: return "REGPROCEDURE" + case .regoper: return "REGOPER" + case .regoperator: return "REGOPERATOR" + case .regclass: return "REGCLASS" + case .regtype: return "REGTYPE" + case .regprocedureArray: return "REGPROCEDURE[]" + case .regoperArray: return "REGOPER[]" + case .regoperatorArray: return "REGOPERATOR[]" + case .regclassArray: return "REGCLASS[]" + case .regtypeArray: return "REGTYPE[]" + case .record: return "RECORD" + case .cstring: return "CSTRING" + case .any: return "ANY" + case .anyarray: return "ANYARRAY" case .void: return "VOID" + case .trigger: return "TRIGGER" + case .languageHandler: return "LANGUAGE_HANDLER" + case .`internal`: return "INTERNAL" + case .anyelement: return "ANYELEMENT" + case .recordArray: return "RECORD[]" + case .anynonarray: return "ANYNONARRAY" case .uuid: return "UUID" case .uuidArray: return "UUID[]" + case .fdwHandler: return "FDW_HANDLER" + case .pgLSN: return "PG_LSN" + case .pgLSNArray: return "PG_LSN[]" + case .tsmHandler: return "TSM_HANDLER" + case .anyenum: return "ANYENUM" + case .tsvector: return "TSVECTOR" + case .tsquery: return "TSQUERY" + case .gtsvector: return "GTSVECTOR" + case .tsvectorArray: return "TSVECTOR[]" + case .gtsvectorArray: return "GTSVECTOR[]" + case .tsqueryArray: return "TSQUERY[]" + case .regconfig: return "REGCONFIG" + case .regconfigArray: return "REGCONFIG[]" + case .regdictionary: return "REGDICTIONARY" + case .regdictionaryArray: return "REGDICTIONARY[]" case .jsonb: return "JSONB" case .jsonbArray: return "JSONB[]" + case .anyrange: return "ANYRANGE" + case .eventTrigger: return "EVENT_TRIGGER" case .int4Range: return "INT4RANGE" case .int4RangeArray: return "INT4RANGE[]" + case .numrange: return "NUMRANGE" + case .numrangeArray: return "NUMRANGE[]" + case .tsrange: return "TSRANGE" + case .tsrangeArray: return "TSRANGE[]" + case .tstzrange: return "TSTZRANGE" + case .tstzrangeArray: return "TSTZRANGE[]" + case .daterange: return "DATERANGE" + case .daterangeArray: return "DATERANGE[]" case .int8Range: return "INT8RANGE" case .int8RangeArray: return "INT8RANGE[]" + case .jsonpath: return "JSONPATH" + case .jsonpathArray: return "JSONPATH[]" + case .regnamespace: return "REGNAMESPACE" + case .regnamespaceArray: return "REGNAMESPACE[]" + case .regrole: return "REGROLE" + case .regroleArray: return "REGROLE[]" + case .regcollation: return "REGCOLLATION" + case .regcollationArray: return "REGCOLLATION[]" + case .int4multirange: return "INT4MULTIRANGE" + case .nummultirange: return "NUMMULTIRANGE" + case .tsmultirange: return "TSMULTIRANGE" + case .tstzmultirange: return "TSTZMULTIRANGE" + case .datemultirange: return "DATEMULTIRANGE" + case .int8multirange: return "INT8MULTIRANGE" + case .anymultirange: return "ANYMULTIRANGE" + case .anycompatiblemultirange: return "ANYCOMPATIBLEMULTIRANGE" + case .xid8: return "XID8" + case .anycompatible: return "ANYCOMPATIBLE" + case .anycompatiblearray: return "ANYCOMPATIBLEARRAY" + case .anycompatiblenonarray: return "ANYCOMPATIBLENONARRAY" + case .anycompatiblerange: return "ANYCOMPATIBLERANG" + case .int4multirangeArray: return "INT4MULTIRANGE[]" + case .nummultirangeArray: return "NUMMULTIRANGE[]" + case .tsmultirangeArray: return "TSMULTIRANGE[]" + case .tstzmultirangeArray: return "TSTZMULTIRANGE[]" + case .datemultirangeArray: return "DATEMULTIRANGE[]" + case .int8multirangeArray: return "INT8MULTIRANGE[]" default: return nil } } /// Returns the array type for this type if one is known. + /// + /// This list was manually generated. internal var arrayType: PostgresDataType? { switch self { + case .xml: return .xmlArray + case .json: return .jsonArray + case .xid8: return .xid8Array + case .line: return .lineArray + case .cidr: return .cidrArray + case .circle: return .circleArray + case .macaddr8Aray: return .macaddr8 + case .money: return .moneyArray + case .int2vector: return .int2vectorArray + case .regproc: return .regprocArray + case .tid: return .tidArray + case .xid: return .xidArray + case .cid: return .cidArray + case .oidvector: return .oidvectorArray + case .bpchar: return .bpcharArray + case .lseg: return .lsegArray + case .path: return .pathArray + case .box: return .boxArray + case .polygon: return .polygonArray + case .oid: return .oidArray + case .aclitem: return .aclitemArray + case .macaddr: return .macaddrArray + case .inet: return .inetArray + case .timestamptz: return .timestamptzArray + case .interval: return .intervalArray + case .numeric: return .numericArray + case .cstring: return .cstringArray + case .timetz: return .timetzArray + case .bit: return .bitArray + case .varbit: return .varbitArray + case .refcursor: return .refcursorArray + case .regprocedure: return .regprocedureArray + case .regoper: return .regoperArray + case .regoperator: return .regoperatorArray + case .regclass: return .regclassArray + case .regtype: return .regtypeArray + case .record: return .recordArray + case .pgLSN: return .pgLSNArray + case .tsvector: return .tsvectorArray + case .gtsvector: return .gtsvectorArray + case .tsquery: return .tsqueryArray + case .regconfig: return .regconfigArray + case .regdictionary: return .regdictionaryArray + case .numrange: return .numrangeArray + case .tsrange: return .tsrangeArray + case .daterange: return .daterangeArray + case .jsonpath: return .jsonpathArray + case .regnamespace: return .regnamespaceArray + case .regrole: return .regroleArray + case .regcollation: return .regcollationArray + case .int4multirange: return .int4multirangeArray + case .tsmultirange: return .tsmultirangeArray + case .tstzmultirange: return .tstzmultirangeArray + case .datemultirange: return .datemultirange + case .int8multirange: return .int8multirangeArray case .bool: return .boolArray case .bytea: return .byteaArray case .char: return .charArray @@ -221,8 +667,65 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri /// Returns the element type for this type if one is known. /// Returns nil if this is not an array type. + /// + /// This list was manually generated. internal var elementType: PostgresDataType? { switch self { + case .xmlArray: return .xml + case .jsonArray: return .json + case .xid8Array: return .xid8 + case .lineArray: return .line + case .cidrArray: return .cidr + case .circleArray: return .circle + case .macaddr8: return .macaddr8Aray + case .moneyArray: return .money + case .int2vectorArray: return .int2vector + case .regprocArray: return .regproc + case .tidArray: return .tid + case .xidArray: return .xid + case .cidArray: return .cid + case .oidvectorArray: return .oidvector + case .bpcharArray: return .bpchar + case .lsegArray: return .lseg + case .pathArray: return .path + case .boxArray: return .box + case .polygonArray: return .polygon + case .oidArray: return .oid + case .aclitemArray: return .aclitem + case .macaddrArray: return .macaddr + case .inetArray: return .inet + case .timestamptzArray: return .timestamptz + case .intervalArray: return .interval + case .numericArray: return .numeric + case .cstringArray: return .cstring + case .timetzArray: return .timetz + case .bitArray: return .bit + case .varbitArray: return .varbit + case .refcursorArray: return .refcursor + case .regprocedureArray: return .regprocedure + case .regoperArray: return .regoper + case .regoperatorArray: return .regoperator + case .regclassArray: return .regclass + case .regtypeArray: return .regtype + case .recordArray: return .record + case .pgLSNArray: return .pgLSN + case .tsvectorArray: return .tsvector + case .gtsvectorArray: return .gtsvector + case .tsqueryArray: return .tsquery + case .regconfigArray: return .regconfig + case .regdictionaryArray: return .regdictionary + case .numrangeArray: return .numrange + case .tsrangeArray: return .tsrange + case .daterangeArray: return .daterange + case .jsonpathArray: return .jsonpath + case .regnamespaceArray: return .regnamespace + case .regroleArray: return .regrole + case .regcollationArray: return .regcollation + case .int4multirangeArray: return .int4multirange + case .tsmultirangeArray: return .tsmultirange + case .tstzmultirangeArray: return .tstzmultirange + case .datemultirange: return .datemultirange + case .int8multirangeArray: return .int8multirange case .boolArray: return .bool case .byteaArray: return .bytea case .charArray: return .char @@ -245,16 +748,22 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri /// Returns the bound type for this type if one is known. /// Returns nil if this is not a range type. + /// + /// This list was manually generated. @usableFromInline internal var boundType: PostgresDataType? { switch self { case .int4Range: return .int4 case .int8Range: return .int8 + case .numrange: return .numeric + case .tsrange: return .timestamp + case .tstzrange: return .timestamptz + case .daterange: return .date default: return nil } } - /// See `CustomStringConvertible`. + /// See ``Swift/CustomStringConvertible/description``. public var description: String { return self.knownSQLName ?? "UNKNOWN \(self.rawValue)" } From 061a0836d7c1887e04a975d1d2eaa2ef5fd7dfab Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Fri, 9 Jun 2023 08:14:56 -0500 Subject: [PATCH 161/292] Add PSQLError debugDescription (#372) Co-authored-by: Fabian Fett --- Sources/PostgresNIO/New/PSQLError.swift | 118 +++++++++++++-- Sources/PostgresNIO/New/PostgresQuery.swift | 136 +++++++++++++++++- Tests/IntegrationTests/AsyncTests.swift | 4 +- .../New/PostgresErrorTests.swift | 59 +++++++- 4 files changed, 292 insertions(+), 25 deletions(-) diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index 08b6a01e..df7dd7c1 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -190,21 +190,13 @@ public struct PSQLError: Error { private final class Backing { fileprivate var code: Code - fileprivate var serverInfo: ServerInfo? - fileprivate var underlying: Error? - fileprivate var file: String? - fileprivate var line: Int? - fileprivate var query: PostgresQuery? - fileprivate var backendMessage: PostgresBackendMessage? - fileprivate var unsupportedAuthScheme: UnsupportedAuthScheme? - fileprivate var invalidCommandTag: String? init(code: Code) { @@ -224,10 +216,10 @@ public struct PSQLError: Error { } public struct ServerInfo { - public struct Field: Hashable, Sendable { + public struct Field: Hashable, Sendable, CustomStringConvertible { fileprivate let backing: PostgresBackendMessage.Field - private init(_ backing: PostgresBackendMessage.Field) { + fileprivate init(_ backing: PostgresBackendMessage.Field) { self.backing = backing } @@ -306,6 +298,47 @@ public struct PSQLError: Error { /// Routine: the name of the source-code routine reporting the error. public static let routine = Self(.routine) + + public var description: String { + switch self.backing { + case .localizedSeverity: + return "localizedSeverity" + case .severity: + return "severity" + case .sqlState: + return "sqlState" + case .message: + return "message" + case .detail: + return "detail" + case .hint: + return "hint" + case .position: + return "position" + case .internalPosition: + return "internalPosition" + case .internalQuery: + return "internalQuery" + case .locationContext: + return "locationContext" + case .schemaName: + return "schemaName" + case .tableName: + return "tableName" + case .columnName: + return "columnName" + case .dataTypeName: + return "dataTypeName" + case .constraintName: + return "constraintName" + case .file: + return "file" + case .line: + return "line" + case .routine: + return "routine" + } + } } let underlying: PostgresBackendMessage.ErrorResponse @@ -397,6 +430,65 @@ public struct PSQLError: Error { } } +extension PSQLError: CustomStringConvertible { + public var description: String { + // This may seem very odd... But we are afraid that users might accidentally send the + // unfiltered errors out to end-users. This may leak security relevant information. For this + // reason we overwrite the error description by default to this generic "Database error" + """ + PSQLError – Generic description to prevent accidental leakage of sensitive data. For debugging details, use `String(reflecting: error)`. + """ + } +} + +extension PSQLError: CustomDebugStringConvertible { + public var debugDescription: String { + var result = #"PSQLError(code: \#(self.code)"# + + if let serverInfo = self.serverInfo?.underlying { + result.append(", serverInfo: [") + result.append( + serverInfo.fields + .sorted(by: { $0.key.rawValue < $1.key.rawValue }) + .map { "\(PSQLError.ServerInfo.Field($0.0)): \($0.1)" } + .joined(separator: ", ") + ) + result.append("]") + } + + if let backendMessage = self.backendMessage { + result.append(", backendMessage: \(String(reflecting: backendMessage))") + } + + if let unsupportedAuthScheme = self.unsupportedAuthScheme { + result.append(", unsupportedAuthScheme: \(unsupportedAuthScheme)") + } + + if let invalidCommandTag = self.invalidCommandTag { + result.append(", invalidCommandTag: \(invalidCommandTag)") + } + + if let underlying = self.underlying { + result.append(", underlying: \(String(reflecting: underlying))") + } + + if let file = self.file { + result.append(", triggeredFromRequestInFile: \(file)") + if let line = self.line { + result.append(", line: \(line)") + } + } + + if let query = self.query { + result.append(", query: \(String(reflecting: query))") + } + + result.append(")") + + return result + } +} + /// An error that may happen when a ``PostgresRow`` or ``PostgresCell`` is decoded to native Swift types. public struct PostgresDecodingError: Error, Equatable { public struct Code: Hashable, Error, CustomStringConvertible { @@ -490,7 +582,9 @@ extension PostgresDecodingError: CustomStringConvertible { // This may seem very odd... But we are afraid that users might accidentally send the // unfiltered errors out to end-users. This may leak security relevant information. For this // reason we overwrite the error description by default to this generic "Database error" - "Database error" + """ + PostgresDecodingError – Generic description to prevent accidental leakage of sensitive data. For debugging details, use `String(reflecting: error)`. + """ } } @@ -504,7 +598,7 @@ extension PostgresDecodingError: CustomDebugStringConvertible { result.append(#", postgresType: \#(self.postgresType)"#) result.append(#", postgresFormat: \#(self.postgresFormat)"#) if let postgresData = self.postgresData { - result.append(#", postgresData: \#(postgresData.debugDescription)"#) // https://github.com/apple/swift-nio/pull/2418 + result.append(#", postgresData: \#(String(reflecting: postgresData))"#) } result.append(#", file: \#(self.file)"#) result.append(#", line: \#(self.line)"#) diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 1ba75050..381370e9 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -95,6 +95,20 @@ extension PostgresQuery { } } +extension PostgresQuery: CustomStringConvertible { + /// See ``Swift/CustomStringConvertible/description``. + public var description: String { + "\(self.sql) \(self.binds)" + } +} + +extension PostgresQuery: CustomDebugStringConvertible { + /// See ``Swift/CustomDebugStringConvertible/debugDescription``. + public var debugDescription: String { + "PostgresQuery(sql: \(String(describing: self.sql)), binds: \(String(reflecting: self.binds)))" + } +} + struct PSQLExecuteStatement { /// The statements name var name: String @@ -111,16 +125,19 @@ public struct PostgresBindings: Sendable, Hashable { var dataType: PostgresDataType @usableFromInline var format: PostgresFormat + @usableFromInline + var protected: Bool @inlinable - init(dataType: PostgresDataType, format: PostgresFormat) { + init(dataType: PostgresDataType, format: PostgresFormat, protected: Bool) { self.dataType = dataType self.format = format + self.protected = protected } @inlinable - init(value: Value) { - self.init(dataType: Value.psqlType, format: Value.psqlFormat) + init(value: Value, protected: Bool) { + self.init(dataType: Value.psqlType, format: Value.psqlFormat, protected: protected) } } @@ -147,7 +164,7 @@ public struct PostgresBindings: Sendable, Hashable { public mutating func appendNull() { self.bytes.writeInteger(-1, as: Int32.self) - self.metadata.append(.init(dataType: .null, format: .binary)) + self.metadata.append(.init(dataType: .null, format: .binary, protected: true)) } @inlinable @@ -156,7 +173,7 @@ public struct PostgresBindings: Sendable, Hashable { context: PostgresEncodingContext ) throws { try value.encodeRaw(into: &self.bytes, context: context) - self.metadata.append(.init(value: value)) + self.metadata.append(.init(value: value, protected: true)) } @inlinable @@ -165,7 +182,25 @@ public struct PostgresBindings: Sendable, Hashable { context: PostgresEncodingContext ) { value.encodeRaw(into: &self.bytes, context: context) - self.metadata.append(.init(value: value)) + self.metadata.append(.init(value: value, protected: true)) + } + + @inlinable + mutating func appendUnprotected( + _ value: Value, + context: PostgresEncodingContext + ) throws { + try value.encodeRaw(into: &self.bytes, context: context) + self.metadata.append(.init(value: value, protected: false)) + } + + @inlinable + mutating func appendUnprotected( + _ value: Value, + context: PostgresEncodingContext + ) { + value.encodeRaw(into: &self.bytes, context: context) + self.metadata.append(.init(value: value, protected: false)) } public mutating func append(_ postgresData: PostgresData) { @@ -176,6 +211,93 @@ public struct PostgresBindings: Sendable, Hashable { self.bytes.writeInteger(Int32(input.readableBytes)) self.bytes.writeBuffer(&input) } - self.metadata.append(.init(dataType: postgresData.type, format: .binary)) + self.metadata.append(.init(dataType: postgresData.type, format: .binary, protected: true)) + } +} + +extension PostgresBindings: CustomStringConvertible, CustomDebugStringConvertible { + /// See ``Swift/CustomStringConvertible/description``. + public var description: String { + """ + [\(zip(self.metadata, BindingsReader(buffer: self.bytes)) + .lazy.map({ Self.makeBindingPrintable(protected: $0.protected, type: $0.dataType, format: $0.format, buffer: $1) }) + .joined(separator: ", "))] + """ + } + + /// See ``Swift/CustomDebugStringConvertible/description``. + public var debugDescription: String { + """ + [\(zip(self.metadata, BindingsReader(buffer: self.bytes)) + .lazy.map({ Self.makeDebugDescription(protected: $0.protected, type: $0.dataType, format: $0.format, buffer: $1) }) + .joined(separator: ", "))] + """ + } + + private static func makeDebugDescription(protected: Bool, type: PostgresDataType, format: PostgresFormat, buffer: ByteBuffer?) -> String { + "(\(Self.makeBindingPrintable(protected: protected, type: type, format: format, buffer: buffer)); \(type); format: \(format))" + } + + private static func makeBindingPrintable(protected: Bool, type: PostgresDataType, format: PostgresFormat, buffer: ByteBuffer?) -> String { + if protected { + return "****" + } + + guard var buffer = buffer else { + return "null" + } + + do { + switch (type, format) { + case (.int4, _), (.int2, _), (.int8, _): + let number = try Int64.init(from: &buffer, type: type, format: format, context: .default) + return String(describing: number) + + case (.bool, _): + let bool = try Bool.init(from: &buffer, type: type, format: format, context: .default) + return String(describing: bool) + + case (.varchar, _), (.bpchar, _), (.text, _), (.name, _): + let value = try String.init(from: &buffer, type: type, format: format, context: .default) + return String(reflecting: value) // adds quotes + + default: + return "\(buffer.readableBytes) bytes" + } + } catch { + return "\(buffer.readableBytes) bytes" + } + } +} + +/// A small helper to inspect encoded bindings +private struct BindingsReader: Sequence { + typealias Element = Optional + + var buffer: ByteBuffer + + struct Iterator: IteratorProtocol { + typealias Element = Optional + private var buffer: ByteBuffer + + init(buffer: ByteBuffer) { + self.buffer = buffer + } + + mutating func next() -> Optional> { + guard let length = self.buffer.readInteger(as: Int32.self) else { + return .none + } + + if length < 0 { + return .some(.none) + } + + return .some(self.buffer.readSlice(length: Int(length))!) + } + } + + func makeIterator() -> Iterator { + Iterator(buffer: self.buffer) } } diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index c96c81f5..7a45c5c0 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -72,7 +72,7 @@ final class AsyncPostgresConnectionTests: XCTestCase { var counter = 0 for try await element in rows.decode((Int, String, String, String, String?, Int, Date, Date, String, String).self) { - XCTAssertEqual(element.1, env("POSTGRES_DB") ?? "localhost") + XCTAssertEqual(element.1, env("POSTGRES_DB") ?? "test_database") XCTAssertEqual(element.2, env("POSTGRES_USER") ?? "test_username") XCTAssertEqual(element.8, query.sql) @@ -106,8 +106,6 @@ final class AsyncPostgresConnectionTests: XCTestCase { } catch { guard let error = error as? PSQLError else { return XCTFail("Unexpected error type") } - print(error) - XCTAssertEqual(error.code, .server) XCTAssertEqual(error.serverInfo?[.severity], "ERROR") } diff --git a/Tests/PostgresNIOTests/New/PostgresErrorTests.swift b/Tests/PostgresNIOTests/New/PostgresErrorTests.swift index 639d6b5e..33df5439 100644 --- a/Tests/PostgresNIOTests/New/PostgresErrorTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresErrorTests.swift @@ -2,6 +2,55 @@ import XCTest import NIOCore +final class PSQLErrorTests: XCTestCase { + func testPostgresBindingsDescription() { + let testBinds1 = PostgresBindings(capacity: 0) + var testBinds2 = PostgresBindings(capacity: 1) + var testBinds3 = PostgresBindings(capacity: 2) + testBinds2.append(1, context: .default) + testBinds3.appendUnprotected(1, context: .default) + testBinds3.appendUnprotected("foo", context: .default) + testBinds3.append("secret", context: .default) + + XCTAssertEqual(String(describing: testBinds1), "[]") + XCTAssertEqual(String(reflecting: testBinds1), "[]") + XCTAssertEqual(String(describing: testBinds2), "[****]") + XCTAssertEqual(String(reflecting: testBinds2), "[(****; BIGINT; format: binary)]") + XCTAssertEqual(String(describing: testBinds3), #"[1, "foo", ****]"#) + XCTAssertEqual(String(reflecting: testBinds3), #"[(1; BIGINT; format: binary), ("foo"; TEXT; format: binary), (****; TEXT; format: binary)]"#) + } + + func testPostgresQueryDescription() { + let testBinds1 = PostgresBindings(capacity: 0) + var testBinds2 = PostgresBindings(capacity: 1) + testBinds2.append(1, context: .default) + let testQuery1 = PostgresQuery(unsafeSQL: "TEST QUERY") + let testQuery2 = PostgresQuery(unsafeSQL: "TEST QUERY", binds: testBinds1) + let testQuery3 = PostgresQuery(unsafeSQL: "TEST QUERY", binds: testBinds2) + + XCTAssertEqual(String(describing: testQuery1), "TEST QUERY []") + XCTAssertEqual(String(reflecting: testQuery1), "PostgresQuery(sql: TEST QUERY, binds: [])") + XCTAssertEqual(String(describing: testQuery2), "TEST QUERY []") + XCTAssertEqual(String(reflecting: testQuery2), "PostgresQuery(sql: TEST QUERY, binds: [])") + XCTAssertEqual(String(describing: testQuery3), "TEST QUERY [****]") + XCTAssertEqual(String(reflecting: testQuery3), "PostgresQuery(sql: TEST QUERY, binds: [(****; BIGINT; format: binary)])") + } + + func testPSQLErrorDescription() { + var error1 = PSQLError.server(.init(fields: [.localizedSeverity: "ERROR", .severity: "ERROR", .sqlState: "00000", .message: "Test message", .detail: "More test message", .hint: "It's a test, that's your hint", .position: "1", .schemaName: "testsch", .tableName: "testtab", .columnName: "testcol", .dataTypeName: "testtyp", .constraintName: "testcon", .file: #fileID, .line: "0", .routine: #function])) + var testBinds = PostgresBindings(capacity: 1) + testBinds.append(1, context: .default) + error1.query = .init(unsafeSQL: "TEST QUERY", binds: testBinds) + + XCTAssertEqual(String(describing: error1), """ + PSQLError – Generic description to prevent accidental leakage of sensitive data. For debugging details, use `String(reflecting: error)`. + """) + XCTAssertEqual(String(reflecting: error1), """ + PSQLError(code: server, serverInfo: [sqlState: 00000, detail: More test message, file: PostgresNIOTests/PostgresErrorTests.swift, hint: It's a test, that's your hint, line: 0, message: Test message, position: 1, routine: testPSQLErrorDescription(), localizedSeverity: ERROR, severity: ERROR, columnName: testcol, dataTypeName: testtyp, constraintName: testcon, schemaName: testsch, tableName: testtab], query: PostgresQuery(sql: TEST QUERY, binds: [(****; BIGINT; format: binary)])) + """) + } +} + final class PostgresDecodingErrorTests: XCTestCase { func testPostgresDecodingErrorEquality() { let error1 = PostgresDecodingError( @@ -59,9 +108,13 @@ final class PostgresDecodingErrorTests: XCTestCase { ) // Plain description - XCTAssertEqual(String(describing: error1), "Database error") - XCTAssertEqual(String(describing: error2), "Database error") - + XCTAssertEqual(String(describing: error1), """ + PostgresDecodingError – Generic description to prevent accidental leakage of sensitive data. For debugging details, use `String(reflecting: error)`. + """) + XCTAssertEqual(String(describing: error2), """ + PostgresDecodingError – Generic description to prevent accidental leakage of sensitive data. For debugging details, use `String(reflecting: error)`. + """) + // Extended debugDescription XCTAssertEqual(String(reflecting: error1), """ PostgresDecodingError(code: typeMismatch,\ From aa9273c06a0f42281635eaf0400aa024157c8fa9 Mon Sep 17 00:00:00 2001 From: Iceman Date: Thu, 20 Jul 2023 17:21:49 +0900 Subject: [PATCH 162/292] Use computed property to PostgresConnection.Configuration.TLS.disable for concurrency safe (#376) --- .../Connection/PostgresConnection+Configuration.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift index 54eefc90..bc9bcfc2 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift @@ -13,7 +13,7 @@ extension PostgresConnection { // MARK: Initializers /// Do not try to create a TLS connection to the server. - public static var disable: Self = .init(base: .disable) + public static var disable: Self { .init(base: .disable) } /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. /// If the server does not support TLS, create an insecure connection. From f3587a586dc5d33b016da6b30d01bbad343c10af Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Sat, 29 Jul 2023 04:01:07 -0500 Subject: [PATCH 163/292] Fix multiple warnings generated by the documentation build (#378) --- .github/workflows/api-docs.yml | 2 +- .github/workflows/test.yml | 18 ++++++++++-------- .../PostgresNIO/Data/PostgresDataType.swift | 2 +- Sources/PostgresNIO/Docs.docc/index.md | 4 ++-- Sources/PostgresNIO/Docs.docc/migrations.md | 2 +- Sources/PostgresNIO/New/PostgresQuery.swift | 8 ++++---- Sources/PostgresNIO/Utilities/Exports.swift | 2 +- 7 files changed, 20 insertions(+), 18 deletions(-) diff --git a/.github/workflows/api-docs.yml b/.github/workflows/api-docs.yml index 80291c6f..dc2e0634 100644 --- a/.github/workflows/api-docs.yml +++ b/.github/workflows/api-docs.yml @@ -11,4 +11,4 @@ jobs: with: package_name: postgres-nio modules: PostgresNIO - pathsToInvalidate: /postgresnio + pathsToInvalidate: /postgresnio/* diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 25374cf3..24821c77 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -26,13 +26,13 @@ jobs: container: ${{ matrix.container }} runs-on: ubuntu-latest steps: + - name: Note Swift version + if: ${{ contains(matrix.swiftver, 'nightly') }} + run: | + echo "SWIFT_PLATFORM=$(. /etc/os-release && echo "${ID}${VERSION_ID}")" >>"${GITHUB_ENV}" + echo "SWIFT_VERSION=$(cat /.swift_tag)" >>"${GITHUB_ENV}" - name: Display OS and Swift versions - shell: bash run: | - if [[ '${{ contains(matrix.container, 'nightly') }}' == 'true' ]]; then - SWIFT_PLATFORM="$(source /etc/os-release && echo "${ID}${VERSION_ID}")" SWIFT_VERSION="$(cat /.swift_tag)" - printf 'SWIFT_PLATFORM=%s\nSWIFT_VERSION=%s\n' "${SWIFT_PLATFORM}" "${SWIFT_VERSION}" >>"${GITHUB_ENV}" - fi printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" && swift --version - name: Check out package uses: actions/checkout@v3 @@ -144,6 +144,7 @@ jobs: POSTGRES_DB: 'postgres' POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} POSTGRES_SOCKET: '/tmp/.s.PGSQL.5432' + POSTGRES_VERSION: ${{ matrix.dbimage }} steps: - name: Select latest available Xcode uses: maxim-lobanov/setup-xcode@v1 @@ -151,9 +152,9 @@ jobs: xcode-version: ${{ matrix.xcode }} - name: Install Postgres, setup DB and auth, and wait for server start run: | - export PATH="$(brew --prefix)/opt/${{ matrix.dbimage }}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test - (brew unlink postgresql || true) && brew install '${{ matrix.dbimage }}' && brew link --force '${{ matrix.dbimage }}' - initdb --locale=C --auth-host '${{ matrix.dbauth }}' -U "${POSTGRES_USER}" --pwfile=<(echo "${POSTGRES_PASSWORD}") + export PATH="$(brew --prefix)/opt/${POSTGRES_VERSION}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test + (brew unlink postgresql || true) && brew install "${POSTGRES_VERSION}" && brew link --force "${POSTGRES_VERSION}" + initdb --locale=C --auth-host "${POSTGRES_HOST_AUTH_METHOD}" -U "${POSTGRES_USER}" --pwfile=<(echo "${POSTGRES_PASSWORD}") pg_ctl start --wait timeout-minutes: 2 - name: Checkout code @@ -175,3 +176,4 @@ jobs: run: git config --global --add safe.directory "${GITHUB_WORKSPACE}" - name: API breaking changes run: swift package diagnose-api-breaking-changes origin/main + diff --git a/Sources/PostgresNIO/Data/PostgresDataType.swift b/Sources/PostgresNIO/Data/PostgresDataType.swift index ede60f47..f3ab4dca 100644 --- a/Sources/PostgresNIO/Data/PostgresDataType.swift +++ b/Sources/PostgresNIO/Data/PostgresDataType.swift @@ -763,7 +763,7 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri } } - /// See ``Swift/CustomStringConvertible/description``. + // See `CustomStringConvertible.description`. public var description: String { return self.knownSQLName ?? "UNKNOWN \(self.rawValue)" } diff --git a/Sources/PostgresNIO/Docs.docc/index.md b/Sources/PostgresNIO/Docs.docc/index.md index e7363054..b4dc7e30 100644 --- a/Sources/PostgresNIO/Docs.docc/index.md +++ b/Sources/PostgresNIO/Docs.docc/index.md @@ -1,12 +1,12 @@ # ``PostgresNIO`` -🐘 Non-blocking, event-driven Swift client for PostgreSQL built on [SwiftNIO]. +🐘 Non-blocking, event-driven Swift client for PostgreSQL built on SwiftNIO. ## Overview Features: -- A ``PostgresConnection`` which allows you to connect to, authorize with, query, and retrieve results from a PostgreSQL server +- A ``PostgresConnection`` which allows you to connect to, authorize with, query, and retrieve results from a PostgreSQL server using [SwiftNIO]. - An async/await interface that supports backpressure - Automatic conversions between Swift primitive types and the Postgres wire format - Integrated with the Swift server ecosystem, including use of [SwiftLog]. diff --git a/Sources/PostgresNIO/Docs.docc/migrations.md b/Sources/PostgresNIO/Docs.docc/migrations.md index 33c8afd4..7185ba06 100644 --- a/Sources/PostgresNIO/Docs.docc/migrations.md +++ b/Sources/PostgresNIO/Docs.docc/migrations.md @@ -6,7 +6,7 @@ which use the ``PostgresRow/column(_:)`` API today. ## TLDR 1. Map your sequence of ``PostgresRow``s to ``PostgresRandomAccessRow``s. -2. Use the ``PostgresRandomAccessRow/subscript(name:)`` API to receive a ``PostgresCell`` +2. Use the ``PostgresRandomAccessRow/subscript(_:)-3facl`` API to receive a ``PostgresCell`` 3. Decode the ``PostgresCell`` into a Swift type using the ``PostgresCell/decode(_:file:line:)`` method. ```swift diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 381370e9..2e06e1d9 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -96,14 +96,14 @@ extension PostgresQuery { } extension PostgresQuery: CustomStringConvertible { - /// See ``Swift/CustomStringConvertible/description``. + // See `CustomStringConvertible.description`. public var description: String { "\(self.sql) \(self.binds)" } } extension PostgresQuery: CustomDebugStringConvertible { - /// See ``Swift/CustomDebugStringConvertible/debugDescription``. + // See `CustomDebugStringConvertible.debugDescription`. public var debugDescription: String { "PostgresQuery(sql: \(String(describing: self.sql)), binds: \(String(reflecting: self.binds)))" } @@ -216,7 +216,7 @@ public struct PostgresBindings: Sendable, Hashable { } extension PostgresBindings: CustomStringConvertible, CustomDebugStringConvertible { - /// See ``Swift/CustomStringConvertible/description``. + // See `CustomStringConvertible.description`. public var description: String { """ [\(zip(self.metadata, BindingsReader(buffer: self.bytes)) @@ -225,7 +225,7 @@ extension PostgresBindings: CustomStringConvertible, CustomDebugStringConvertibl """ } - /// See ``Swift/CustomDebugStringConvertible/description``. + // See `CustomDebugStringConvertible.description`. public var debugDescription: String { """ [\(zip(self.metadata, BindingsReader(buffer: self.bytes)) diff --git a/Sources/PostgresNIO/Utilities/Exports.swift b/Sources/PostgresNIO/Utilities/Exports.swift index 204df50c..58e12891 100644 --- a/Sources/PostgresNIO/Utilities/Exports.swift +++ b/Sources/PostgresNIO/Utilities/Exports.swift @@ -1,4 +1,4 @@ -#if compiler(>=5.8) +#if swift(>=5.8) @_documentation(visibility: internal) @_exported import NIO @_documentation(visibility: internal) @_exported import NIOSSL From 718d154ad788b9e3fca73c83016a03d70d018dfb Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 3 Aug 2023 17:30:47 +0200 Subject: [PATCH 164/292] Crash fix: Multiple bad messages could trigger reentrancy issue (#379) If we receive multiple unexpected messages from the backend we can run into a reentrancy situation in which we still have unread messages in the incoming buffer after we have received `channelInactive`. This pr patches this crash. --- .../ConnectionStateMachine.swift | 26 ++--- .../New/PostgresChannelHandler.swift | 107 ++++++++++-------- .../New/PostgresChannelHandlerTests.swift | 39 ++++++- 3 files changed, 111 insertions(+), 61 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 563bb026..ba1e3c1f 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -928,7 +928,7 @@ struct ConnectionStateMachine { .forwardStreamComplete, .wait, .read: - preconditionFailure("Expecting only failure actions if an error happened") + preconditionFailure("Invalid state: \(self.state)") case .evaluateErrorAtConnectionLevel: return .closeConnectionAndCleanup(cleanupContext) case .failQuery(let queryContext, with: let error): @@ -951,7 +951,7 @@ struct ConnectionStateMachine { .succeedPreparedStatementCreation, .read, .wait: - preconditionFailure("Expecting only failure actions if an error happened") + preconditionFailure("Invalid state: \(self.state)") case .failPreparedStatementCreation(let preparedStatementContext, with: let error): return .failPreparedStatementCreation(preparedStatementContext, with: error, cleanupContext: cleanupContext) } @@ -970,22 +970,20 @@ struct ConnectionStateMachine { .succeedClose, .read, .wait: - preconditionFailure("Expecting only failure actions if an error happened") + preconditionFailure("Invalid state: \(self.state)") case .failClose(let closeCommandContext, with: let error): return .failClose(closeCommandContext, with: error, cleanupContext: cleanupContext) } - case .error: - // TBD: this is an interesting case. why would this case happen? - let cleanupContext = self.setErrorAndCreateCleanupContext(error) - return .closeConnectionAndCleanup(cleanupContext) - - case .closing: - let cleanupContext = self.setErrorAndCreateCleanupContext(error) - return .closeConnectionAndCleanup(cleanupContext) - case .closed: - preconditionFailure("How can an error occur if the connection is already closed?") + case .error, .closing, .closed: + // We might run into this case because of reentrancy. For example: After we received an + // backend unexpected message, that we read of the wire, we bring this connection into + // the error state and will try to close the connection. However the server might have + // send further follow up messages. In those cases we will run into this method again + // and again. We should just ignore those events. + return .wait + case .modifying: - preconditionFailure("Invalid state") + preconditionFailure("Invalid state: \(self.state)") } } diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 84f07d47..fdb6a443 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -84,6 +84,17 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } func channelInactive(context: ChannelHandlerContext) { + do { + try self.decoder.finishProcessing(seenEOF: true) { message in + self.handleMessage(message, context: context) + } + } catch let error as PostgresMessageDecodingError { + let action = self.state.errorHappened(.messageDecodingFailure(error)) + self.run(action, with: context) + } catch { + preconditionFailure("Expected to only get PSQLDecodingErrors from the PSQLBackendMessageDecoder.") + } + self.logger.trace("Channel inactive.") let action = self.state.closed() self.run(action, with: context) @@ -100,51 +111,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { do { try self.decoder.process(buffer: buffer) { message in - self.logger.trace("Backend message received", metadata: [.message: "\(message)"]) - let action: ConnectionStateMachine.ConnectionAction - - switch message { - case .authentication(let authentication): - action = self.state.authenticationMessageReceived(authentication) - case .backendKeyData(let keyData): - action = self.state.backendKeyDataReceived(keyData) - case .bindComplete: - action = self.state.bindCompleteReceived() - case .closeComplete: - action = self.state.closeCompletedReceived() - case .commandComplete(let commandTag): - action = self.state.commandCompletedReceived(commandTag) - case .dataRow(let dataRow): - action = self.state.dataRowReceived(dataRow) - case .emptyQueryResponse: - action = self.state.emptyQueryResponseReceived() - case .error(let errorResponse): - action = self.state.errorReceived(errorResponse) - case .noData: - action = self.state.noDataReceived() - case .notice(let noticeResponse): - action = self.state.noticeReceived(noticeResponse) - case .notification(let notification): - action = self.state.notificationReceived(notification) - case .parameterDescription(let parameterDescription): - action = self.state.parameterDescriptionReceived(parameterDescription) - case .parameterStatus(let parameterStatus): - action = self.state.parameterStatusReceived(parameterStatus) - case .parseComplete: - action = self.state.parseCompleteReceived() - case .portalSuspended: - action = self.state.portalSuspendedReceived() - case .readyForQuery(let transactionState): - action = self.state.readyForQueryReceived(transactionState) - case .rowDescription(let rowDescription): - action = self.state.rowDescriptionReceived(rowDescription) - case .sslSupported: - action = self.state.sslSupportedReceived(unprocessedBytes: self.decoder.unprocessedBytes) - case .sslUnsupported: - action = self.state.sslUnsupportedReceived() - } - - self.run(action, with: context) + self.handleMessage(message, context: context) } } catch let error as PostgresMessageDecodingError { let action = self.state.errorHappened(.messageDecodingFailure(error)) @@ -153,7 +120,55 @@ final class PostgresChannelHandler: ChannelDuplexHandler { preconditionFailure("Expected to only get PSQLDecodingErrors from the PSQLBackendMessageDecoder.") } } - + + private func handleMessage(_ message: PostgresBackendMessage, context: ChannelHandlerContext) { + self.logger.trace("Backend message received", metadata: [.message: "\(message)"]) + let action: ConnectionStateMachine.ConnectionAction + + switch message { + case .authentication(let authentication): + action = self.state.authenticationMessageReceived(authentication) + case .backendKeyData(let keyData): + action = self.state.backendKeyDataReceived(keyData) + case .bindComplete: + action = self.state.bindCompleteReceived() + case .closeComplete: + action = self.state.closeCompletedReceived() + case .commandComplete(let commandTag): + action = self.state.commandCompletedReceived(commandTag) + case .dataRow(let dataRow): + action = self.state.dataRowReceived(dataRow) + case .emptyQueryResponse: + action = self.state.emptyQueryResponseReceived() + case .error(let errorResponse): + action = self.state.errorReceived(errorResponse) + case .noData: + action = self.state.noDataReceived() + case .notice(let noticeResponse): + action = self.state.noticeReceived(noticeResponse) + case .notification(let notification): + action = self.state.notificationReceived(notification) + case .parameterDescription(let parameterDescription): + action = self.state.parameterDescriptionReceived(parameterDescription) + case .parameterStatus(let parameterStatus): + action = self.state.parameterStatusReceived(parameterStatus) + case .parseComplete: + action = self.state.parseCompleteReceived() + case .portalSuspended: + action = self.state.portalSuspendedReceived() + case .readyForQuery(let transactionState): + action = self.state.readyForQueryReceived(transactionState) + case .rowDescription(let rowDescription): + action = self.state.rowDescriptionReceived(rowDescription) + case .sslSupported: + action = self.state.sslSupportedReceived(unprocessedBytes: self.decoder.unprocessedBytes) + case .sslUnsupported: + action = self.state.sslUnsupportedReceived() + } + + self.run(action, with: context) + } + func channelReadComplete(context: ChannelHandlerContext) { let action = self.state.channelReadComplete() self.run(action, with: context) diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index 7ab0ce30..d76b8223 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -198,7 +198,44 @@ class PostgresChannelHandlerTests: XCTestCase { XCTAssertEqual(message, .password(.init(value: password))) } - + + func testHandlerThatSendsMultipleWrongMessages() { + let config = self.testConnectionConfiguration() + let handler = PostgresChannelHandler(configuration: config, configureSSLCallback: nil) + let embedded = EmbeddedChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), + handler + ]) + + var maybeMessage: PostgresFrontendMessage? + XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil)) + XCTAssertNoThrow(maybeMessage = try embedded.readOutbound(as: PostgresFrontendMessage.self)) + guard case .startup(let startup) = maybeMessage else { + return XCTFail("Unexpected message") + } + + XCTAssertEqual(startup.parameters.user, config.username) + XCTAssertEqual(startup.parameters.database, config.database) + XCTAssertEqual(startup.parameters.options, nil) + XCTAssertEqual(startup.parameters.replication, .false) + + var buffer = ByteBuffer() + buffer.writeMultipleIntegers(UInt8(ascii: "R"), UInt32(8), Int32(0)) + buffer.writeMultipleIntegers(UInt8(ascii: "K"), UInt32(12), Int32(1234), Int32(5678)) + buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I")) + XCTAssertNoThrow(try embedded.writeInbound(buffer)) + XCTAssertTrue(embedded.isActive) + + buffer.clear() + buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I")) + buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I")) + buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I")) + buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I")) + + XCTAssertThrowsError(try embedded.writeInbound(buffer)) + XCTAssertFalse(embedded.isActive) + } + // MARK: Helpers func testConnectionConfiguration( From 4fd297db09ea09c6007b4abdec056f5f5387bb27 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 4 Aug 2023 22:55:49 +0200 Subject: [PATCH 165/292] PostgresFrontendMessage: refactor encoding (#381) --- .../New/BufferedMessageEncoder.swift | 35 --- Sources/PostgresNIO/New/Messages/Bind.swift | 45 ---- Sources/PostgresNIO/New/Messages/Cancel.swift | 21 -- Sources/PostgresNIO/New/Messages/Close.swift | 20 -- .../PostgresNIO/New/Messages/Describe.swift | 21 -- .../PostgresNIO/New/Messages/Execute.swift | 23 -- Sources/PostgresNIO/New/Messages/Parse.swift | 26 --- .../PostgresNIO/New/Messages/Password.swift | 13 -- .../New/Messages/SASLInitialResponse.swift | 28 --- .../New/Messages/SASLResponse.swift | 19 -- .../PostgresNIO/New/Messages/SSLRequest.swift | 21 -- .../PostgresNIO/New/Messages/Startup.swift | 40 +--- .../New/PSQLFrontendMessageEncoder.swift | 85 -------- .../New/PostgresChannelHandler.swift | 113 +++++----- .../New/PostgresFrontendMessage.swift | 94 +++++++- .../New/PostgresFrontendMessageEncoder.swift | 205 ++++++++++++++++++ .../PSQLFrontendMessageDecoder.swift | 2 +- .../New/Messages/BindTests.swift | 12 +- .../New/Messages/CancelTests.swift | 15 +- .../New/Messages/CloseTests.swift | 20 +- .../New/Messages/DescribeTests.swift | 18 +- .../New/Messages/ExecuteTests.swift | 9 +- .../New/Messages/ParseTests.swift | 39 ++-- .../New/Messages/PasswordTests.swift | 8 +- .../Messages/SASLInitialResponseTests.swift | 37 ++-- .../New/Messages/SASLResponseTests.swift | 26 +-- .../New/Messages/SSLRequestTests.swift | 12 +- .../New/Messages/StartupTests.swift | 11 +- .../New/PSQLFrontendMessageTests.swift | 24 +- .../New/PostgresChannelHandlerTests.swift | 20 +- 30 files changed, 464 insertions(+), 598 deletions(-) delete mode 100644 Sources/PostgresNIO/New/BufferedMessageEncoder.swift delete mode 100644 Sources/PostgresNIO/New/Messages/Bind.swift delete mode 100644 Sources/PostgresNIO/New/Messages/Cancel.swift delete mode 100644 Sources/PostgresNIO/New/Messages/Close.swift delete mode 100644 Sources/PostgresNIO/New/Messages/Describe.swift delete mode 100644 Sources/PostgresNIO/New/Messages/Execute.swift delete mode 100644 Sources/PostgresNIO/New/Messages/Parse.swift delete mode 100644 Sources/PostgresNIO/New/Messages/Password.swift delete mode 100644 Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift delete mode 100644 Sources/PostgresNIO/New/Messages/SASLResponse.swift delete mode 100644 Sources/PostgresNIO/New/Messages/SSLRequest.swift delete mode 100644 Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift create mode 100644 Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift diff --git a/Sources/PostgresNIO/New/BufferedMessageEncoder.swift b/Sources/PostgresNIO/New/BufferedMessageEncoder.swift deleted file mode 100644 index f202fcff..00000000 --- a/Sources/PostgresNIO/New/BufferedMessageEncoder.swift +++ /dev/null @@ -1,35 +0,0 @@ -import NIOCore - -struct BufferedMessageEncoder { - private enum State { - case flushed - case writable - } - - private var buffer: ByteBuffer - private var state: State = .writable - private var encoder: PSQLFrontendMessageEncoder - - init(buffer: ByteBuffer, encoder: PSQLFrontendMessageEncoder) { - self.buffer = buffer - self.encoder = encoder - } - - mutating func encode(_ message: PostgresFrontendMessage) { - switch self.state { - case .flushed: - self.state = .writable - self.buffer.clear() - - case .writable: - break - } - - self.encoder.encode(data: message, out: &self.buffer) - } - - mutating func flush() -> ByteBuffer { - self.state = .flushed - return self.buffer - } -} diff --git a/Sources/PostgresNIO/New/Messages/Bind.swift b/Sources/PostgresNIO/New/Messages/Bind.swift deleted file mode 100644 index 898018d4..00000000 --- a/Sources/PostgresNIO/New/Messages/Bind.swift +++ /dev/null @@ -1,45 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - - struct Bind: PSQLMessagePayloadEncodable, Equatable { - /// The name of the destination portal (an empty string selects the unnamed portal). - var portalName: String - - /// The name of the source prepared statement (an empty string selects the unnamed prepared statement). - var preparedStatementName: String - - /// The number of parameter values that follow (possibly zero). This must match the number of parameters needed by the query. - var bind: PostgresBindings - - func encode(into buffer: inout ByteBuffer) { - buffer.writeNullTerminatedString(self.portalName) - buffer.writeNullTerminatedString(self.preparedStatementName) - - // The number of parameter format codes that follow (denoted C below). This can be - // zero to indicate that there are no parameters or that the parameters all use the - // default format (text); or one, in which case the specified format code is applied - // to all parameters; or it can equal the actual number of parameters. - buffer.writeInteger(UInt16(self.bind.count)) - - // The parameter format codes. Each must presently be zero (text) or one (binary). - self.bind.metadata.forEach { - buffer.writeInteger($0.format.rawValue) - } - - buffer.writeInteger(UInt16(self.bind.count)) - - var parametersCopy = self.bind.bytes - buffer.writeBuffer(¶metersCopy) - - // The number of result-column format codes that follow (denoted R below). This can be - // zero to indicate that there are no result columns or that the result columns should - // all use the default format (text); or one, in which case the specified format code - // is applied to all result columns (if any); or it can equal the actual number of - // result columns of the query. - buffer.writeInteger(1, as: Int16.self) - // The result-column format codes. Each must presently be zero (text) or one (binary). - buffer.writeInteger(PostgresFormat.binary.rawValue, as: Int16.self) - } - } -} diff --git a/Sources/PostgresNIO/New/Messages/Cancel.swift b/Sources/PostgresNIO/New/Messages/Cancel.swift deleted file mode 100644 index 2f29d239..00000000 --- a/Sources/PostgresNIO/New/Messages/Cancel.swift +++ /dev/null @@ -1,21 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - - struct Cancel: PSQLMessagePayloadEncodable, Equatable { - /// The cancel request code. The value is chosen to contain 1234 in the most significant 16 bits, - /// and 5678 in the least significant 16 bits. (To avoid confusion, this code must not be the same - /// as any protocol version number.) - let cancelRequestCode: Int32 = 80877102 - - /// The process ID of the target backend. - let processID: Int32 - - /// The secret key for the target backend. - let secretKey: Int32 - - func encode(into buffer: inout ByteBuffer) { - buffer.writeMultipleIntegers(self.cancelRequestCode, self.processID, self.secretKey) - } - } -} diff --git a/Sources/PostgresNIO/New/Messages/Close.swift b/Sources/PostgresNIO/New/Messages/Close.swift deleted file mode 100644 index 7f038f94..00000000 --- a/Sources/PostgresNIO/New/Messages/Close.swift +++ /dev/null @@ -1,20 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - - enum Close: PSQLMessagePayloadEncodable, Equatable { - case preparedStatement(String) - case portal(String) - - func encode(into buffer: inout ByteBuffer) { - switch self { - case .preparedStatement(let name): - buffer.writeInteger(UInt8(ascii: "S")) - buffer.writeNullTerminatedString(name) - case .portal(let name): - buffer.writeInteger(UInt8(ascii: "P")) - buffer.writeNullTerminatedString(name) - } - } - } -} diff --git a/Sources/PostgresNIO/New/Messages/Describe.swift b/Sources/PostgresNIO/New/Messages/Describe.swift deleted file mode 100644 index 76167d32..00000000 --- a/Sources/PostgresNIO/New/Messages/Describe.swift +++ /dev/null @@ -1,21 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - - enum Describe: PSQLMessagePayloadEncodable, Equatable { - - case preparedStatement(String) - case portal(String) - - func encode(into buffer: inout ByteBuffer) { - switch self { - case .preparedStatement(let name): - buffer.writeInteger(UInt8(ascii: "S")) - buffer.writeNullTerminatedString(name) - case .portal(let name): - buffer.writeInteger(UInt8(ascii: "P")) - buffer.writeNullTerminatedString(name) - } - } - } -} diff --git a/Sources/PostgresNIO/New/Messages/Execute.swift b/Sources/PostgresNIO/New/Messages/Execute.swift deleted file mode 100644 index 17646484..00000000 --- a/Sources/PostgresNIO/New/Messages/Execute.swift +++ /dev/null @@ -1,23 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - - struct Execute: PSQLMessagePayloadEncodable, Equatable { - /// The name of the portal to execute (an empty string selects the unnamed portal). - let portalName: String - - /// Maximum number of rows to return, if portal contains a query that returns rows (ignored otherwise). Zero denotes “no limit”. - let maxNumberOfRows: Int32 - - init(portalName: String, maxNumberOfRows: Int32 = 0) { - self.portalName = portalName - self.maxNumberOfRows = maxNumberOfRows - } - - func encode(into buffer: inout ByteBuffer) { - buffer.writeNullTerminatedString(self.portalName) - buffer.writeInteger(self.maxNumberOfRows) - } - } - -} diff --git a/Sources/PostgresNIO/New/Messages/Parse.swift b/Sources/PostgresNIO/New/Messages/Parse.swift deleted file mode 100644 index 9d3cfa0b..00000000 --- a/Sources/PostgresNIO/New/Messages/Parse.swift +++ /dev/null @@ -1,26 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - - struct Parse: PSQLMessagePayloadEncodable, Equatable { - /// The name of the destination prepared statement (an empty string selects the unnamed prepared statement). - let preparedStatementName: String - - /// The query string to be parsed. - let query: String - - /// The number of parameter data types specified (can be zero). Note that this is not an indication of the number of parameters that might appear in the query string, only the number that the frontend wants to prespecify types for. - let parameters: [PostgresDataType] - - func encode(into buffer: inout ByteBuffer) { - buffer.writeNullTerminatedString(self.preparedStatementName) - buffer.writeNullTerminatedString(self.query) - buffer.writeInteger(UInt16(self.parameters.count)) - - self.parameters.forEach { dataType in - buffer.writeInteger(dataType.rawValue) - } - } - } - -} diff --git a/Sources/PostgresNIO/New/Messages/Password.swift b/Sources/PostgresNIO/New/Messages/Password.swift deleted file mode 100644 index 81d7ab30..00000000 --- a/Sources/PostgresNIO/New/Messages/Password.swift +++ /dev/null @@ -1,13 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - - struct Password: PSQLMessagePayloadEncodable, Equatable { - let value: String - - func encode(into buffer: inout ByteBuffer) { - buffer.writeNullTerminatedString(value) - } - } - -} diff --git a/Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift b/Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift deleted file mode 100644 index 73db9332..00000000 --- a/Sources/PostgresNIO/New/Messages/SASLInitialResponse.swift +++ /dev/null @@ -1,28 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - - struct SASLInitialResponse: PSQLMessagePayloadEncodable, Equatable { - - let saslMechanism: String - let initialData: [UInt8] - - /// Creates a new `SSLRequest`. - init(saslMechanism: String, initialData: [UInt8]) { - self.saslMechanism = saslMechanism - self.initialData = initialData - } - - /// Serializes this message into a byte buffer. - func encode(into buffer: inout ByteBuffer) { - buffer.writeNullTerminatedString(self.saslMechanism) - - if self.initialData.count > 0 { - buffer.writeInteger(Int32(self.initialData.count)) - buffer.writeBytes(self.initialData) - } else { - buffer.writeInteger(Int32(-1)) - } - } - } -} diff --git a/Sources/PostgresNIO/New/Messages/SASLResponse.swift b/Sources/PostgresNIO/New/Messages/SASLResponse.swift deleted file mode 100644 index a6709dcd..00000000 --- a/Sources/PostgresNIO/New/Messages/SASLResponse.swift +++ /dev/null @@ -1,19 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - - struct SASLResponse: PSQLMessagePayloadEncodable, Equatable { - - let data: [UInt8] - - /// Creates a new `SSLRequest`. - init(data: [UInt8]) { - self.data = data - } - - /// Serializes this message into a byte buffer. - func encode(into buffer: inout ByteBuffer) { - buffer.writeBytes(self.data) - } - } -} diff --git a/Sources/PostgresNIO/New/Messages/SSLRequest.swift b/Sources/PostgresNIO/New/Messages/SSLRequest.swift deleted file mode 100644 index 6f9c45a3..00000000 --- a/Sources/PostgresNIO/New/Messages/SSLRequest.swift +++ /dev/null @@ -1,21 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - /// A message asking the PostgreSQL server if TLS is supported - /// For more info, see https://www.postgresql.org/docs/10/static/protocol-flow.html#id-1.10.5.7.11 - struct SSLRequest: PSQLMessagePayloadEncodable, Equatable { - /// The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits, - /// and 5679 in the least significant 16 bits. - let code: Int32 - - /// Creates a new `SSLRequest`. - init() { - self.code = 80877103 - } - - /// Serializes this message into a byte buffer. - func encode(into buffer: inout ByteBuffer) { - buffer.writeInteger(self.code) - } - } -} diff --git a/Sources/PostgresNIO/New/Messages/Startup.swift b/Sources/PostgresNIO/New/Messages/Startup.swift index f7da2127..16d23e09 100644 --- a/Sources/PostgresNIO/New/Messages/Startup.swift +++ b/Sources/PostgresNIO/New/Messages/Startup.swift @@ -1,13 +1,14 @@ import NIOCore extension PostgresFrontendMessage { - struct Startup: PSQLMessagePayloadEncodable, Equatable { + struct Startup: Hashable { + static let versionThree: Int32 = 0x00_03_00_00 /// Creates a `Startup` with "3.0" as the protocol version. static func versionThree(parameters: Parameters) -> Startup { - return .init(protocolVersion: 0x00_03_00_00, parameters: parameters) + return .init(protocolVersion: Self.versionThree, parameters: parameters) } - + /// The protocol version number. The most significant 16 bits are the major /// version number (3 for the protocol described here). The least significant /// 16 bits are the minor version number (0 for the protocol described here). @@ -16,7 +17,7 @@ extension PostgresFrontendMessage { /// The protocol version number is followed by one or more pairs of parameter /// name and value strings. A zero byte is required as a terminator after /// the last name/value pair. `user` is required, others are optional. - struct Parameters: Equatable { + struct Parameters: Hashable { enum Replication { case `true` case `false` @@ -47,36 +48,5 @@ extension PostgresFrontendMessage { self.protocolVersion = protocolVersion self.parameters = parameters } - - /// Serializes this message into a byte buffer. - func encode(into buffer: inout ByteBuffer) { - buffer.writeInteger(self.protocolVersion) - buffer.writeNullTerminatedString("user") - buffer.writeNullTerminatedString(self.parameters.user) - - if let database = self.parameters.database { - buffer.writeNullTerminatedString("database") - buffer.writeNullTerminatedString(database) - } - - if let options = self.parameters.options { - buffer.writeNullTerminatedString("options") - buffer.writeNullTerminatedString(options) - } - - switch self.parameters.replication { - case .database: - buffer.writeNullTerminatedString("replication") - buffer.writeNullTerminatedString("replication") - case .true: - buffer.writeNullTerminatedString("replication") - buffer.writeNullTerminatedString("true") - case .false: - break - } - - buffer.writeInteger(UInt8(0)) - } } - } diff --git a/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift deleted file mode 100644 index 24155d84..00000000 --- a/Sources/PostgresNIO/New/PSQLFrontendMessageEncoder.swift +++ /dev/null @@ -1,85 +0,0 @@ -import NIOCore - -struct PSQLFrontendMessageEncoder: MessageToByteEncoder { - typealias OutboundIn = PostgresFrontendMessage - - init() {} - - func encode(data message: PostgresFrontendMessage, out buffer: inout ByteBuffer) { - switch message { - case .bind(let bind): - buffer.writeInteger(message.id.rawValue) - let startIndex = buffer.writerIndex - buffer.writeInteger(Int32(0)) // placeholder for length - bind.encode(into: &buffer) - let length = Int32(buffer.writerIndex - startIndex) - buffer.setInteger(length, at: startIndex) - - case .cancel(let cancel): - // cancel requests don't have an identifier - self.encode(payload: cancel, into: &buffer) - - case .close(let close): - self.encode(messageID: message.id, payload: close, into: &buffer) - - case .describe(let describe): - self.encode(messageID: message.id, payload: describe, into: &buffer) - - case .execute(let execute): - self.encode(messageID: message.id, payload: execute, into: &buffer) - - case .flush: - self.encode(messageID: message.id, payload: EmptyPayload(), into: &buffer) - - case .parse(let parse): - self.encode(messageID: message.id, payload: parse, into: &buffer) - - case .password(let password): - self.encode(messageID: message.id, payload: password, into: &buffer) - - case .saslInitialResponse(let saslInitialResponse): - self.encode(messageID: message.id, payload: saslInitialResponse, into: &buffer) - - case .saslResponse(let saslResponse): - self.encode(messageID: message.id, payload: saslResponse, into: &buffer) - - case .sslRequest(let request): - // sslRequests don't have an identifier - self.encode(payload: request, into: &buffer) - - case .startup(let startup): - // startup requests don't have an identifier - self.encode(payload: startup, into: &buffer) - - case .sync: - self.encode(messageID: message.id, payload: EmptyPayload(), into: &buffer) - - case .terminate: - self.encode(messageID: message.id, payload: EmptyPayload(), into: &buffer) - } - } - - private struct EmptyPayload: PSQLMessagePayloadEncodable { - func encode(into buffer: inout ByteBuffer) {} - } - - private func encode( - messageID: PostgresFrontendMessage.ID, - payload: Payload, - into buffer: inout ByteBuffer) - { - buffer.psqlWriteFrontendMessageID(messageID) - self.encode(payload: payload, into: &buffer) - } - - private func encode( - payload: Payload, - into buffer: inout ByteBuffer) - { - let startIndex = buffer.writerIndex - buffer.writeInteger(Int32(0)) // placeholder for length - payload.encode(into: &buffer) - let length = Int32(buffer.writerIndex - startIndex) - buffer.setInteger(length, at: startIndex) - } -} diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index fdb6a443..09feb521 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -21,7 +21,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private var handlerContext: ChannelHandlerContext? private var rowStream: PSQLRowStream? private var decoder: NIOSingleStepByteToMessageProcessor - private var encoder: BufferedMessageEncoder! + private var encoder: PostgresFrontendMessageEncoder! private let configuration: PostgresConnection.InternalConfiguration private let configureSSLCallback: ((Channel) throws -> Void)? @@ -58,10 +58,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { func handlerAdded(context: ChannelHandlerContext) { self.handlerContext = context - self.encoder = BufferedMessageEncoder( - buffer: context.channel.allocator.buffer(capacity: 256), - encoder: PSQLFrontendMessageEncoder() - ) + self.encoder = PostgresFrontendMessageEncoder(buffer: context.channel.allocator.buffer(capacity: 256)) if context.channel.isActive { self.connected(context: context) @@ -239,19 +236,19 @@ final class PostgresChannelHandler: ChannelDuplexHandler { case .wait: break case .sendStartupMessage(let authContext): - self.encoder.encode(.startup(.versionThree(parameters: authContext.toStartupParameters()))) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.startup(authContext.toStartupParameters()) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .sendSSLRequest: - self.encoder.encode(.sslRequest(.init())) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.ssl() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .sendPasswordMessage(let mode, let authContext): self.sendPasswordMessage(mode: mode, authContext: authContext, context: context) case .sendSaslInitialResponse(let name, let initialResponse): - self.encoder.encode(.saslInitialResponse(.init(saslMechanism: name, initialData: initialResponse))) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.saslInitialResponse(mechanism: name, bytes: initialResponse) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .sendSaslResponse(let bytes): - self.encoder.encode(.saslResponse(.init(data: bytes))) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.saslResponse(bytes) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .closeConnectionAndCleanup(let cleanupContext): self.closeConnectionAndCleanup(cleanupContext, context: context) case .fireChannelInactive: @@ -315,8 +312,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { // The normal, graceful termination procedure is that the frontend sends a Terminate // message and immediately closes the connection. On receipt of this message, the // backend closes the connection and terminates. - self.encoder.encode(.terminate) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.terminate() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) } context.close(mode: .all, promise: promise) case .succeedPreparedStatementCreation(let preparedContext, with: let rowDescription): @@ -381,89 +378,79 @@ final class PostgresChannelHandler: ChannelDuplexHandler { hash2.append(salt.3) let hash = Insecure.MD5.hash(data: hash2).md5PrefixHexdigest() - self.encoder.encode(.password(.init(value: hash))) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.password(hash.utf8) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .cleartext: - self.encoder.encode(.password(.init(value: authContext.password ?? ""))) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.password((authContext.password ?? "").utf8) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) } } private func sendCloseAndSyncMessage(_ sendClose: CloseTarget, context: ChannelHandlerContext) { switch sendClose { case .preparedStatement(let name): - self.encoder.encode(.close(.preparedStatement(name))) - self.encoder.encode(.sync) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.closePreparedStatement(name) + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .portal(let name): - self.encoder.encode(.close(.portal(name))) - self.encoder.encode(.sync) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.closePortal(name) + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) } } private func sendParseDecribeAndSyncMessage( statementName: String, query: String, - context: ChannelHandlerContext) - { + context: ChannelHandlerContext + ) { precondition(self.rowStream == nil, "Expected to not have an open stream at this point") - let parse = PostgresFrontendMessage.Parse( - preparedStatementName: statementName, - query: query, - parameters: []) - self.encoder.encode(.parse(parse)) - self.encoder.encode(.describe(.preparedStatement(statementName))) - self.encoder.encode(.sync) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + self.encoder.parse(preparedStatementName: statementName, query: query, parameters: []) + self.encoder.describePreparedStatement(statementName) + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) } private func sendBindExecuteAndSyncMessage( executeStatement: PSQLExecuteStatement, context: ChannelHandlerContext ) { - let bind = PostgresFrontendMessage.Bind( + self.encoder.bind( portalName: "", preparedStatementName: executeStatement.name, - bind: executeStatement.binds) - - self.encoder.encode(.bind(bind)) - self.encoder.encode(.execute(.init(portalName: ""))) - self.encoder.encode(.sync) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + bind: executeStatement.binds + ) + self.encoder.execute(portalName: "") + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) } private func sendParseDescribeBindExecuteAndSyncMessage( query: PostgresQuery, - context: ChannelHandlerContext) - { + context: ChannelHandlerContext + ) { precondition(self.rowStream == nil, "Expected to not have an open stream at this point") let unnamedStatementName = "" - let parse = PostgresFrontendMessage.Parse( + self.encoder.parse( preparedStatementName: unnamedStatementName, query: query.sql, - parameters: query.binds.metadata.map(\.dataType)) - let bind = PostgresFrontendMessage.Bind( - portalName: "", - preparedStatementName: unnamedStatementName, - bind: query.binds) - - self.encoder.encode(.parse(parse)) - self.encoder.encode(.describe(.preparedStatement(""))) - self.encoder.encode(.bind(bind)) - self.encoder.encode(.execute(.init(portalName: ""))) - self.encoder.encode(.sync) - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()), promise: nil) + parameters: query.binds.metadata.lazy.map(\.dataType) + ) + self.encoder.describePreparedStatement(unnamedStatementName) + self.encoder.bind(portalName: "", preparedStatementName: unnamedStatementName, bind: query.binds) + self.encoder.execute(portalName: "") + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) } private func succeedQueryWithRowStream( _ queryContext: ExtendedQueryContext, columns: [RowDescription.Column], - context: ChannelHandlerContext) - { + context: ChannelHandlerContext + ) { let rows = PSQLRowStream( rowDescription: columns, queryContext: queryContext, @@ -477,8 +464,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private func succeedQueryWithoutRowStream( _ queryContext: ExtendedQueryContext, commandTag: String, - context: ChannelHandlerContext) - { + context: ChannelHandlerContext + ) { let rows = PSQLRowStream( rowDescription: [], queryContext: queryContext, @@ -490,8 +477,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private func closeConnectionAndCleanup( _ cleanup: ConnectionStateMachine.ConnectionAction.CleanUpContext, - context: ChannelHandlerContext) - { + context: ChannelHandlerContext + ) { self.logger.debug("Cleaning up and closing connection.", metadata: [.error: "\(cleanup.error)"]) // 1. fail all tasks diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessage.swift b/Sources/PostgresNIO/New/PostgresFrontendMessage.swift index 2017cd1a..3963bd62 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessage.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessage.swift @@ -5,6 +5,98 @@ import NIOCore /// All messages are defined in the official Postgres Documentation in the section /// [Frontend/Backend Protocol – Message Formats](https://www.postgresql.org/docs/13/protocol-message-formats.html) enum PostgresFrontendMessage: Equatable { + + struct Bind: Hashable { + /// The name of the destination portal (an empty string selects the unnamed portal). + var portalName: String + + /// The name of the source prepared statement (an empty string selects the unnamed prepared statement). + var preparedStatementName: String + + /// The number of parameter values that follow (possibly zero). This must match the number of parameters needed by the query. + var bind: PostgresBindings + } + + struct Cancel: Equatable { + /// The cancel request code. The value is chosen to contain 1234 in the most significant 16 bits, + /// and 5678 in the least significant 16 bits. (To avoid confusion, this code must not be the same + /// as any protocol version number.) + static let requestCode: Int32 = 80877102 + + /// The process ID of the target backend. + let processID: Int32 + + /// The secret key for the target backend. + let secretKey: Int32 + } + + enum Close: Hashable { + case preparedStatement(String) + case portal(String) + } + + enum Describe: Hashable { + case preparedStatement(String) + case portal(String) + } + + struct Execute: Hashable { + /// The name of the portal to execute (an empty string selects the unnamed portal). + let portalName: String + + /// Maximum number of rows to return, if portal contains a query that returns rows (ignored otherwise). Zero denotes “no limit”. + let maxNumberOfRows: Int32 + + init(portalName: String, maxNumberOfRows: Int32 = 0) { + self.portalName = portalName + self.maxNumberOfRows = maxNumberOfRows + } + } + + struct Parse: Hashable { + /// The name of the destination prepared statement (an empty string selects the unnamed prepared statement). + let preparedStatementName: String + + /// The query string to be parsed. + let query: String + + /// The number of parameter data types specified (can be zero). Note that this is not an indication of the number of parameters that might appear in the query string, only the number that the frontend wants to prespecify types for. + let parameters: [PostgresDataType] + } + + struct Password: Hashable { + let value: String + } + + struct SASLInitialResponse: Hashable { + + let saslMechanism: String + let initialData: [UInt8] + + /// Creates a new `SSLRequest`. + init(saslMechanism: String, initialData: [UInt8]) { + self.saslMechanism = saslMechanism + self.initialData = initialData + } + } + + struct SASLResponse: Hashable { + var data: [UInt8] + + /// Creates a new `SSLRequest`. + init(data: [UInt8]) { + self.data = data + } + } + + /// A message asking the PostgreSQL server if TLS is supported + /// For more info, see https://www.postgresql.org/docs/10/static/protocol-flow.html#id-1.10.5.7.11 + struct SSLRequest: Hashable { + /// The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits, + /// and 5679 in the least significant 16 bits. + static let requestCode: Int32 = 80877103 + } + case bind(Bind) case cancel(Cancel) case close(Close) @@ -15,7 +107,7 @@ enum PostgresFrontendMessage: Equatable { case password(Password) case saslInitialResponse(SASLInitialResponse) case saslResponse(SASLResponse) - case sslRequest(SSLRequest) + case sslRequest case sync case startup(Startup) case terminate diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift new file mode 100644 index 00000000..46dbba42 --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift @@ -0,0 +1,205 @@ +import NIOCore + +struct PostgresFrontendMessageEncoder { + private enum State { + case flushed + case writable + } + + private var buffer: ByteBuffer + private var state: State = .writable + + init(buffer: ByteBuffer) { + self.buffer = buffer + } + + mutating func startup(_ parameters: PostgresFrontendMessage.Startup.Parameters) { + self.clearIfNeeded() + self.encodeLengthPrefixed { buffer in + buffer.writeInteger(PostgresFrontendMessage.Startup.versionThree) + buffer.writeNullTerminatedString("user") + buffer.writeNullTerminatedString(parameters.user) + + if let database = parameters.database { + buffer.writeNullTerminatedString("database") + buffer.writeNullTerminatedString(database) + } + + if let options = parameters.options { + buffer.writeNullTerminatedString("options") + buffer.writeNullTerminatedString(options) + } + + switch parameters.replication { + case .database: + buffer.writeNullTerminatedString("replication") + buffer.writeNullTerminatedString("replication") + case .true: + buffer.writeNullTerminatedString("replication") + buffer.writeNullTerminatedString("true") + case .false: + break + } + + buffer.writeInteger(UInt8(0)) + } + } + + mutating func bind(portalName: String, preparedStatementName: String, bind: PostgresBindings) { + self.clearIfNeeded() + self.buffer.psqlWriteFrontendMessageID(.bind) + self.encodeLengthPrefixed { buffer in + buffer.writeNullTerminatedString(portalName) + buffer.writeNullTerminatedString(preparedStatementName) + + // The number of parameter format codes that follow (denoted C below). This can be + // zero to indicate that there are no parameters or that the parameters all use the + // default format (text); or one, in which case the specified format code is applied + // to all parameters; or it can equal the actual number of parameters. + buffer.writeInteger(UInt16(bind.count)) + + // The parameter format codes. Each must presently be zero (text) or one (binary). + bind.metadata.forEach { + buffer.writeInteger($0.format.rawValue) + } + + buffer.writeInteger(UInt16(bind.count)) + + var parametersCopy = bind.bytes + buffer.writeBuffer(¶metersCopy) + + // The number of result-column format codes that follow (denoted R below). This can be + // zero to indicate that there are no result columns or that the result columns should + // all use the default format (text); or one, in which case the specified format code + // is applied to all result columns (if any); or it can equal the actual number of + // result columns of the query. + buffer.writeInteger(1, as: Int16.self) + // The result-column format codes. Each must presently be zero (text) or one (binary). + buffer.writeInteger(PostgresFormat.binary.rawValue, as: Int16.self) + } + } + + mutating func cancel(processID: Int32, secretKey: Int32) { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(UInt32(16), PostgresFrontendMessage.Cancel.requestCode, processID, secretKey) + } + + mutating func closePreparedStatement(_ preparedStatement: String) { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.close.rawValue, UInt32(6 + preparedStatement.utf8.count), UInt8(ascii: "S")) + self.buffer.writeNullTerminatedString(preparedStatement) + } + + mutating func closePortal(_ portal: String) { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.close.rawValue, UInt32(6 + portal.utf8.count), UInt8(ascii: "P")) + self.buffer.writeNullTerminatedString(portal) + } + + mutating func describePreparedStatement(_ preparedStatement: String) { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.describe.rawValue, UInt32(6 + preparedStatement.utf8.count), UInt8(ascii: "S")) + self.buffer.writeNullTerminatedString(preparedStatement) + } + + mutating func describePortal(_ portal: String) { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.describe.rawValue, UInt32(6 + portal.utf8.count), UInt8(ascii: "P")) + self.buffer.writeNullTerminatedString(portal) + } + + mutating func execute(portalName: String, maxNumberOfRows: Int32 = 0) { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.execute.rawValue, UInt32(9 + portalName.utf8.count)) + self.buffer.writeNullTerminatedString(portalName) + self.buffer.writeInteger(maxNumberOfRows) + } + + mutating func parse(preparedStatementName: String, query: String, parameters: Parameters) where Parameters.Element == PostgresDataType { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers( + PostgresFrontendMessage.ID.parse.rawValue, + UInt32(4 + preparedStatementName.utf8.count + 1 + query.utf8.count + 1 + 2 + MemoryLayout.size * parameters.count) + ) + self.buffer.writeNullTerminatedString(preparedStatementName) + self.buffer.writeNullTerminatedString(query) + self.buffer.writeInteger(UInt16(parameters.count)) + + for dataType in parameters { + self.buffer.writeInteger(dataType.rawValue) + } + } + + mutating func password(_ bytes: Bytes) where Bytes.Element == UInt8 { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.password.rawValue, UInt32(5 + bytes.count)) + self.buffer.writeBytes(bytes) + self.buffer.writeInteger(UInt8(0)) + } + + mutating func flush() { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.flush.rawValue, UInt32(4)) + } + + mutating func saslResponse(_ bytes: Bytes) where Bytes.Element == UInt8 { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.saslResponse.rawValue, UInt32(4 + bytes.count)) + self.buffer.writeBytes(bytes) + } + + mutating func saslInitialResponse(mechanism: String, bytes: Bytes) where Bytes.Element == UInt8 { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers( + PostgresFrontendMessage.ID.saslInitialResponse.rawValue, + UInt32(4 + mechanism.utf8.count + 1 + 4 + bytes.count) + ) + self.buffer.writeNullTerminatedString(mechanism) + if bytes.count > 0 { + self.buffer.writeInteger(Int32(bytes.count)) + self.buffer.writeBytes(bytes) + } else { + self.buffer.writeInteger(Int32(-1)) + } + } + + mutating func ssl() { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(UInt32(8), PostgresFrontendMessage.SSLRequest.requestCode) + } + + mutating func sync() { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.sync.rawValue, UInt32(4)) + } + + mutating func terminate() { + self.clearIfNeeded() + self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.terminate.rawValue, UInt32(4)) + } + + mutating func flushBuffer() -> ByteBuffer { + self.state = .flushed + return self.buffer + } + + private mutating func clearIfNeeded() { + switch self.state { + case .flushed: + self.state = .writable + self.buffer.clear() + + case .writable: + break + } + } + + private mutating func encodeLengthPrefixed(_ encode: (inout ByteBuffer) -> ()) { + let startIndex = self.buffer.writerIndex + self.buffer.writeInteger(UInt32(0)) // placeholder for length + encode(&self.buffer) + let length = UInt32(self.buffer.writerIndex - startIndex) + self.buffer.setInteger(length, at: startIndex) + } + +} diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index 311c41bd..342907ea 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -34,7 +34,7 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { switch code { case 80877103: self.isInStartup = true - return .sslRequest(.init()) + return .sslRequest case 196608: var user: String? diff --git a/Tests/PostgresNIOTests/New/Messages/BindTests.swift b/Tests/PostgresNIOTests/New/Messages/BindTests.swift index 85768b10..d5ec5b30 100644 --- a/Tests/PostgresNIOTests/New/Messages/BindTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BindTests.swift @@ -5,15 +5,15 @@ import NIOCore class BindTests: XCTestCase { func testEncodeBind() { - let encoder = PSQLFrontendMessageEncoder() var bindings = PostgresBindings() bindings.append("Hello", context: .default) bindings.append("World", context: .default) - var byteBuffer = ByteBuffer() - let bind = PostgresFrontendMessage.Bind(portalName: "", preparedStatementName: "", bind: bindings) - let message = PostgresFrontendMessage.bind(bind) - encoder.encode(data: message, out: &byteBuffer) - + + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + + encoder.bind(portalName: "", preparedStatementName: "", bind: bindings) + var byteBuffer = encoder.flushBuffer() + XCTAssertEqual(byteBuffer.readableBytes, 37) XCTAssertEqual(PostgresFrontendMessage.ID.bind.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 36) diff --git a/Tests/PostgresNIOTests/New/Messages/CancelTests.swift b/Tests/PostgresNIOTests/New/Messages/CancelTests.swift index c42f1999..5548aae3 100644 --- a/Tests/PostgresNIOTests/New/Messages/CancelTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CancelTests.swift @@ -5,18 +5,17 @@ import NIOCore class CancelTests: XCTestCase { func testEncodeCancel() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let cancel = PostgresFrontendMessage.Cancel(processID: 1234, secretKey: 4567) - let message = PostgresFrontendMessage.cancel(cancel) - encoder.encode(data: message, out: &byteBuffer) + let processID: Int32 = 1234 + let secretKey: Int32 = 4567 + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.cancel(processID: processID, secretKey: secretKey) + var byteBuffer = encoder.flushBuffer() XCTAssertEqual(byteBuffer.readableBytes, 16) XCTAssertEqual(16, byteBuffer.readInteger(as: Int32.self)) // payload length XCTAssertEqual(80877102, byteBuffer.readInteger(as: Int32.self)) // cancel request code - XCTAssertEqual(cancel.processID, byteBuffer.readInteger(as: Int32.self)) - XCTAssertEqual(cancel.secretKey, byteBuffer.readInteger(as: Int32.self)) + XCTAssertEqual(processID, byteBuffer.readInteger(as: Int32.self)) + XCTAssertEqual(secretKey, byteBuffer.readInteger(as: Int32.self)) XCTAssertEqual(byteBuffer.readableBytes, 0) } - } diff --git a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift index f6a0237b..a8e1cfeb 100644 --- a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift @@ -3,13 +3,11 @@ import NIOCore @testable import PostgresNIO class CloseTests: XCTestCase { - func testEncodeClosePortal() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let message = PostgresFrontendMessage.close(.portal("Hello")) - encoder.encode(data: message, out: &byteBuffer) - + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.closePortal("Hello") + var byteBuffer = encoder.flushBuffer() + XCTAssertEqual(byteBuffer.readableBytes, 12) XCTAssertEqual(PostgresFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(11, byteBuffer.readInteger(as: Int32.self)) @@ -19,11 +17,10 @@ class CloseTests: XCTestCase { } func testEncodeCloseUnnamedStatement() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let message = PostgresFrontendMessage.close(.preparedStatement("")) - encoder.encode(data: message, out: &byteBuffer) - + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.closePreparedStatement("") + var byteBuffer = encoder.flushBuffer() + XCTAssertEqual(byteBuffer.readableBytes, 7) XCTAssertEqual(PostgresFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(6, byteBuffer.readInteger(as: Int32.self)) @@ -31,5 +28,4 @@ class CloseTests: XCTestCase { XCTAssertEqual("", byteBuffer.readNullTerminatedString()) XCTAssertEqual(byteBuffer.readableBytes, 0) } - } diff --git a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift index df26f3d7..cb3c745b 100644 --- a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift @@ -5,11 +5,10 @@ import NIOCore class DescribeTests: XCTestCase { func testEncodeDescribePortal() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let message = PostgresFrontendMessage.describe(.portal("Hello")) - encoder.encode(data: message, out: &byteBuffer) - + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.describePortal("Hello") + var byteBuffer = encoder.flushBuffer() + XCTAssertEqual(byteBuffer.readableBytes, 12) XCTAssertEqual(PostgresFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(11, byteBuffer.readInteger(as: Int32.self)) @@ -19,11 +18,10 @@ class DescribeTests: XCTestCase { } func testEncodeDescribeUnnamedStatement() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let message = PostgresFrontendMessage.describe(.preparedStatement("")) - encoder.encode(data: message, out: &byteBuffer) - + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.describePreparedStatement("") + var byteBuffer = encoder.flushBuffer() + XCTAssertEqual(byteBuffer.readableBytes, 7) XCTAssertEqual(PostgresFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(6, byteBuffer.readInteger(as: Int32.self)) diff --git a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift index dc5e2767..834ad0dd 100644 --- a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift @@ -5,11 +5,10 @@ import NIOCore class ExecuteTests: XCTestCase { func testEncodeExecute() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let message = PostgresFrontendMessage.execute(.init(portalName: "", maxNumberOfRows: 0)) - encoder.encode(data: message, out: &byteBuffer) - + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.execute(portalName: "", maxNumberOfRows: 0) + var byteBuffer = encoder.flushBuffer() + XCTAssertEqual(byteBuffer.readableBytes, 10) // 1 (id) + 4 (length) + 1 (empty null terminated string) + 4 (count) XCTAssertEqual(PostgresFrontendMessage.ID.execute.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(9, byteBuffer.readInteger(as: Int32.self)) // length diff --git a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift index 723ad1e6..9f81e4e4 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift @@ -3,18 +3,19 @@ import NIOCore @testable import PostgresNIO class ParseTests: XCTestCase { - func testEncode() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let parse = PostgresFrontendMessage.Parse( - preparedStatementName: "test", - query: "SELECT version()", - parameters: [.bool, .int8, .bytea, .varchar, .text, .uuid, .json, .jsonbArray]) - let message = PostgresFrontendMessage.parse(parse) - encoder.encode(data: message, out: &byteBuffer) + let preparedStatementName = "test" + let query = "SELECT version()" + let parameters: [PostgresDataType] = [.bool, .int8, .bytea, .varchar, .text, .uuid, .json, .jsonbArray] + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.parse( + preparedStatementName: preparedStatementName, + query: query, + parameters: parameters + ) + var byteBuffer = encoder.flushBuffer() - let length: Int = 1 + 4 + (parse.preparedStatementName.count + 1) + (parse.query.count + 1) + 2 + parse.parameters.count * 4 + let length: Int = 1 + 4 + (preparedStatementName.count + 1) + (query.count + 1) + 2 + parameters.count * 4 // 1 id // + 4 length @@ -24,17 +25,11 @@ class ParseTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, length) XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.parse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), parse.preparedStatementName) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), parse.query) - XCTAssertEqual(byteBuffer.readInteger(as: UInt16.self), UInt16(parse.parameters.count)) - XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.bool.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.int8.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.bytea.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.varchar.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.text.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.uuid.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.json.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.jsonbArray.rawValue) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), preparedStatementName) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), query) + XCTAssertEqual(byteBuffer.readInteger(as: UInt16.self), UInt16(parameters.count)) + for dataType in parameters { + XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), dataType.rawValue) + } } - } diff --git a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift index 7572d382..4a4833d2 100644 --- a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift @@ -5,11 +5,11 @@ import NIOCore class PasswordTests: XCTestCase { func testEncodePassword() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) // md522d085ed8dc3377968dc1c1a40519a2a = "abc123" with salt 1, 2, 3, 4 - let message = PostgresFrontendMessage.password(.init(value: "md522d085ed8dc3377968dc1c1a40519a2a")) - encoder.encode(data: message, out: &byteBuffer) + let password = "md522d085ed8dc3377968dc1c1a40519a2a" + encoder.password(password.utf8) + var byteBuffer = encoder.flushBuffer() let expectedLength = 41 // 1 (id) + 4 (length) + 35 (string) + 1 (null termination) diff --git a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift index 08b3097d..90aa6b34 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift @@ -4,15 +4,14 @@ import NIOCore class SASLInitialResponseTests: XCTestCase { - func testEncodeWithData() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let sasl = PostgresFrontendMessage.SASLInitialResponse( - saslMechanism: "hello", initialData: [0, 1, 2, 3, 4, 5, 6, 7]) - let message = PostgresFrontendMessage.saslInitialResponse(sasl) - encoder.encode(data: message, out: &byteBuffer) + func testEncode() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + let saslMechanism = "hello" + let initialData: [UInt8] = [0, 1, 2, 3, 4, 5, 6, 7] + encoder.saslInitialResponse(mechanism: saslMechanism, bytes: initialData) + var byteBuffer = encoder.flushBuffer() - let length: Int = 1 + 4 + (sasl.saslMechanism.count + 1) + 4 + sasl.initialData.count + let length: Int = 1 + 4 + (saslMechanism.count + 1) + 4 + initialData.count // 1 id // + 4 length @@ -23,21 +22,20 @@ class SASLInitialResponseTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, length) XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslInitialResponse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), sasl.saslMechanism) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(sasl.initialData.count)) - XCTAssertEqual(byteBuffer.readBytes(length: sasl.initialData.count), sasl.initialData) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), saslMechanism) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(initialData.count)) + XCTAssertEqual(byteBuffer.readBytes(length: initialData.count), initialData) XCTAssertEqual(byteBuffer.readableBytes, 0) } func testEncodeWithoutData() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let sasl = PostgresFrontendMessage.SASLInitialResponse( - saslMechanism: "hello", initialData: []) - let message = PostgresFrontendMessage.saslInitialResponse(sasl) - encoder.encode(data: message, out: &byteBuffer) + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + let saslMechanism = "hello" + let initialData: [UInt8] = [] + encoder.saslInitialResponse(mechanism: saslMechanism, bytes: initialData) + var byteBuffer = encoder.flushBuffer() - let length: Int = 1 + 4 + (sasl.saslMechanism.count + 1) + 4 + sasl.initialData.count + let length: Int = 1 + 4 + (saslMechanism.count + 1) + 4 + initialData.count // 1 id // + 4 length @@ -48,8 +46,9 @@ class SASLInitialResponseTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, length) XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslInitialResponse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), sasl.saslMechanism) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), saslMechanism) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(-1)) + XCTAssertEqual(byteBuffer.readBytes(length: initialData.count), initialData) XCTAssertEqual(byteBuffer.readableBytes, 0) } } diff --git a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift index e148420f..cdb0f10b 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift @@ -5,28 +5,26 @@ import NIOCore class SASLResponseTests: XCTestCase { func testEncodeWithData() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let sasl = PostgresFrontendMessage.SASLResponse(data: [0, 1, 2, 3, 4, 5, 6, 7]) - let message = PostgresFrontendMessage.saslResponse(sasl) - encoder.encode(data: message, out: &byteBuffer) - - let length: Int = 1 + 4 + (sasl.data.count) + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + let data: [UInt8] = [0, 1, 2, 3, 4, 5, 6, 7] + encoder.saslResponse(data) + var byteBuffer = encoder.flushBuffer() + + let length: Int = 1 + 4 + (data.count) XCTAssertEqual(byteBuffer.readableBytes, length) XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslResponse.rawValue) XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readBytes(length: sasl.data.count), sasl.data) + XCTAssertEqual(byteBuffer.readBytes(length: data.count), data) XCTAssertEqual(byteBuffer.readableBytes, 0) } func testEncodeWithoutData() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let sasl = PostgresFrontendMessage.SASLResponse(data: []) - let message = PostgresFrontendMessage.saslResponse(sasl) - encoder.encode(data: message, out: &byteBuffer) - + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + let data: [UInt8] = [] + encoder.saslResponse(data) + var byteBuffer = encoder.flushBuffer() + let length: Int = 1 + 4 XCTAssertEqual(byteBuffer.readableBytes, length) diff --git a/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift b/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift index 9a973f2b..e9e6af81 100644 --- a/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift @@ -5,16 +5,14 @@ import NIOCore class SSLRequestTests: XCTestCase { func testSSLRequest() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - let request = PostgresFrontendMessage.SSLRequest() - let message = PostgresFrontendMessage.sslRequest(request) - encoder.encode(data: message, out: &byteBuffer) + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.ssl() + var byteBuffer = encoder.flushBuffer() let byteBufferLength = Int32(byteBuffer.readableBytes) XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) - XCTAssertEqual(request.code, byteBuffer.readInteger()) - + XCTAssertEqual(PostgresFrontendMessage.SSLRequest.requestCode, byteBuffer.readInteger()) + XCTAssertEqual(byteBuffer.readableBytes, 0) } diff --git a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift index 08a9ee21..e72f0f34 100644 --- a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift @@ -5,7 +5,7 @@ import NIOCore class StartupTests: XCTestCase { func testStartupMessage() { - let encoder = PSQLFrontendMessageEncoder() + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) var byteBuffer = ByteBuffer() let replicationValues: [PostgresFrontendMessage.Startup.Parameters.Replication] = [ @@ -22,13 +22,12 @@ class StartupTests: XCTestCase { replication: replication ) - let startup = PostgresFrontendMessage.Startup.versionThree(parameters: parameters) - let message = PostgresFrontendMessage.startup(startup) - encoder.encode(data: message, out: &byteBuffer) - + encoder.startup(parameters) + byteBuffer = encoder.flushBuffer() + let byteBufferLength = Int32(byteBuffer.readableBytes) XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) - XCTAssertEqual(startup.protocolVersion, byteBuffer.readInteger()) + XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database") diff --git a/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift index 59b69bae..33afbe0d 100644 --- a/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift @@ -23,30 +23,30 @@ class PSQLFrontendMessageTests: XCTestCase { // MARK: Encoder func testEncodeFlush() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - encoder.encode(data: .flush, out: &byteBuffer) - + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.flush() + var byteBuffer = encoder.flushBuffer() + XCTAssertEqual(byteBuffer.readableBytes, 5) XCTAssertEqual(PostgresFrontendMessage.ID.flush.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(4, byteBuffer.readInteger(as: Int32.self)) // payload length } func testEncodeSync() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - encoder.encode(data: .sync, out: &byteBuffer) - + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.sync() + var byteBuffer = encoder.flushBuffer() + XCTAssertEqual(byteBuffer.readableBytes, 5) XCTAssertEqual(PostgresFrontendMessage.ID.sync.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(4, byteBuffer.readInteger(as: Int32.self)) // payload length } func testEncodeTerminate() { - let encoder = PSQLFrontendMessageEncoder() - var byteBuffer = ByteBuffer() - encoder.encode(data: .terminate, out: &byteBuffer) - + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.terminate() + var byteBuffer = encoder.flushBuffer() + XCTAssertEqual(byteBuffer.readableBytes, 5) XCTAssertEqual(PostgresFrontendMessage.ID.terminate.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(4, byteBuffer.readInteger(as: Int32.self)) // payload length diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index d76b8223..97ad892f 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -49,15 +49,9 @@ class PostgresChannelHandlerTests: XCTestCase { handler ]) - var maybeMessage: PostgresFrontendMessage? XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil)) - XCTAssertNoThrow(maybeMessage = try embedded.readOutbound(as: PostgresFrontendMessage.self)) - guard case .sslRequest(let request) = maybeMessage else { - return XCTFail("Unexpected message") - } - - XCTAssertEqual(request.code, 80877103) - + XCTAssertEqual(.sslRequest, try embedded.readOutbound(as: PostgresFrontendMessage.self)) + XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.sslSupported)) // a NIOSSLHandler has been added, after it SSL had been negotiated @@ -92,14 +86,8 @@ class PostgresChannelHandlerTests: XCTestCase { eventHandler ]) - var maybeMessage: PostgresFrontendMessage? XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil)) - XCTAssertNoThrow(maybeMessage = try embedded.readOutbound(as: PostgresFrontendMessage.self)) - guard case .sslRequest(let request) = maybeMessage else { - return XCTFail("Unexpected message") - } - - XCTAssertEqual(request.code, 80877103) + XCTAssertEqual(.sslRequest, try embedded.readOutbound(as: PostgresFrontendMessage.self)) var responseBuffer = ByteBuffer() responseBuffer.writeInteger(UInt8(ascii: "S")) @@ -134,7 +122,7 @@ class PostgresChannelHandlerTests: XCTestCase { XCTAssertTrue(embedded.isActive) // read the ssl request message - XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .sslRequest(.init())) + XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .sslRequest) try embedded.writeInbound(PostgresBackendMessage.sslUnsupported) // the event handler should have seen an error From 0c9391c68a38be8d9990688717fe26eaad41e395 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 5 Aug 2023 14:49:47 +0200 Subject: [PATCH 166/292] Add async listen; Refactor all listen code (#264) --- .../Connection/PostgresConnection.swift | 139 +++++----- .../ConnectionStateMachine.swift | 5 +- .../ListenStateMachine.swift | 247 ++++++++++++++++++ .../New/NotificationListener.swift | 157 +++++++++++ Sources/PostgresNIO/New/PSQLError.swift | 15 ++ Sources/PostgresNIO/New/PSQLTask.swift | 12 +- .../New/PostgresChannelHandler.swift | 241 ++++++++++++++--- .../New/PostgresFrontendMessage.swift | 7 +- .../New/PostgresNotificationSequence.swift | 22 ++ Sources/PostgresNIO/Postgres+PSQLCompat.swift | 4 +- Tests/IntegrationTests/AsyncTests.swift | 23 ++ .../PSQLFrontendMessageDecoder.swift | 81 +++++- .../New/PSQLConnectionTests.swift | 37 --- .../New/PostgresChannelHandlerTests.swift | 43 +-- .../New/PostgresConnectionTests.swift | 245 +++++++++++++++++ 15 files changed, 1104 insertions(+), 174 deletions(-) create mode 100644 Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift create mode 100644 Sources/PostgresNIO/New/NotificationListener.swift create mode 100644 Sources/PostgresNIO/New/PostgresNotificationSequence.swift delete mode 100644 Tests/PostgresNIOTests/New/PSQLConnectionTests.swift create mode 100644 Tests/PostgresNIOTests/New/PostgresConnectionTests.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index c24041c9..d6420a6e 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -38,15 +38,7 @@ public final class PostgresConnection: @unchecked Sendable { } } - /// A dictionary to store notification callbacks in - /// - /// Those are used when `PostgresConnection.addListener` is invoked. This only lives here since properties - /// can not be added in extensions. All relevant code lives in `PostgresConnection+Notifications` - var notificationListeners: [String: [(PostgresListenContext, (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void)]] = [:] { - willSet { - self.channel.eventLoop.preconditionInEventLoop() - } - } + private let internalListenID = ManagedAtomic(0) public var isClosed: Bool { return !self.channel.isActive @@ -87,10 +79,10 @@ public final class PostgresConnection: @unchecked Sendable { let channelHandler = PostgresChannelHandler( configuration: configuration, + eventLoop: channel.eventLoop, logger: logger, configureSSLCallback: configureSSLCallback ) - channelHandler.notificationDelegate = self let eventHandler = PSQLEventsHandler(logger: logger) @@ -164,14 +156,16 @@ public final class PostgresConnection: @unchecked Sendable { // thread and the EventLoop. return eventLoop.flatSubmit { () -> EventLoopFuture in let connectFuture: EventLoopFuture - let bootstrap = self.makeBootstrap(on: eventLoop, configuration: configuration) switch configuration.connection { case .resolved(let address): + let bootstrap = self.makeBootstrap(on: eventLoop, configuration: configuration) connectFuture = bootstrap.connect(to: address) case .unresolvedTCP(let host, let port): + let bootstrap = self.makeBootstrap(on: eventLoop, configuration: configuration) connectFuture = bootstrap.connect(host: host, port: port) case .unresolvedUDS(let path): + let bootstrap = self.makeBootstrap(on: eventLoop, configuration: configuration) connectFuture = bootstrap.connect(unixDomainSocketPath: path) case .bootstrapped(let channel): guard channel.isActive else { @@ -224,9 +218,10 @@ public final class PostgresConnection: @unchecked Sendable { let context = ExtendedQueryContext( query: query, logger: logger, - promise: promise) + promise: promise + ) - self.channel.write(PSQLTask.extendedQuery(context), promise: nil) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) return promise.futureResult } @@ -241,7 +236,7 @@ public final class PostgresConnection: @unchecked Sendable { logger: logger, promise: promise) - self.channel.write(PSQLTask.preparedStatement(context), promise: nil) + self.channel.write(HandlerTask.preparedStatement(context), promise: nil) return promise.futureResult.map { rowDescription in PSQLPreparedStatement(name: name, query: query, connection: self, rowDescription: rowDescription) } @@ -257,7 +252,7 @@ public final class PostgresConnection: @unchecked Sendable { logger: logger, promise: promise) - self.channel.write(PSQLTask.extendedQuery(context), promise: nil) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) return promise.futureResult } @@ -265,7 +260,7 @@ public final class PostgresConnection: @unchecked Sendable { let promise = self.channel.eventLoop.makePromise(of: Void.self) let context = CloseCommandContext(target: target, logger: logger, promise: promise) - self.channel.write(PSQLTask.closeCommand(context), promise: nil) + self.channel.write(HandlerTask.closeCommand(context), promise: nil) return promise.futureResult } @@ -417,7 +412,7 @@ extension PostgresConnection { promise: promise ) - self.channel.write(PSQLTask.extendedQuery(context), promise: nil) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) do { return try await promise.futureResult.map({ $0.asyncSequence() }).get() @@ -428,6 +423,31 @@ extension PostgresConnection { throw error // rethrow with more metadata } } + + /// Start listening for a channel + public func listen(_ channel: String) async throws -> PostgresNotificationSequence { + let id = self.internalListenID.loadThenWrappingIncrement(ordering: .relaxed) + + return try await withTaskCancellationHandler { + try Task.checkCancellation() + + return try await withCheckedThrowingContinuation { continuation in + let listener = NotificationListener( + channel: channel, + id: id, + eventLoop: self.eventLoop, + checkedContinuation: continuation + ) + + let task = HandlerTask.startListening(listener) + + self.channel.write(task, promise: nil) + } + } onCancel: { + let task = HandlerTask.cancelListening(channel, id) + self.channel.write(task, promise: nil) + } + } } // MARK: EventLoopFuture interface @@ -569,73 +589,58 @@ internal enum PostgresCommands: PostgresRequest { // MARK: Notifications /// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support. -public final class PostgresListenContext { - var stopper: (() -> Void)? +public final class PostgresListenContext: Sendable { + private let promise: EventLoopPromise + + var future: EventLoopFuture { + self.promise.futureResult + } + + init(promise: EventLoopPromise) { + self.promise = promise + } + + func cancel() { + self.promise.succeed() + } /// Detach this listener so it no longer receives notifications. Other listeners, including those for the same channel, are unaffected. `UNLISTEN` is not sent; you are responsible for issuing an `UNLISTEN` query yourself if it is appropriate for your application. public func stop() { - stopper?() - stopper = nil + self.promise.succeed() } } extension PostgresConnection { /// Add a handler for NotificationResponse messages on a certain channel. This is used in conjunction with PostgreSQL's `LISTEN`/`NOTIFY` support: to listen on a channel, you add a listener using this method to handle the NotificationResponse messages, then issue a `LISTEN` query to instruct PostgreSQL to begin sending NotificationResponse messages. @discardableResult - public func addListener(channel: String, handler notificationHandler: @escaping (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void) -> PostgresListenContext { + @preconcurrency + public func addListener( + channel: String, + handler notificationHandler: @Sendable @escaping (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void + ) -> PostgresListenContext { + let listenContext = PostgresListenContext(promise: self.eventLoop.makePromise(of: Void.self)) + let id = self.internalListenID.loadThenWrappingIncrement(ordering: .relaxed) + + let listener = NotificationListener( + channel: channel, + id: id, + eventLoop: self.eventLoop, + context: listenContext, + closure: notificationHandler + ) - let listenContext = PostgresListenContext() + let task = HandlerTask.startListening(listener) + self.channel.write(task, promise: nil) - self.channel.pipeline.handler(type: PostgresChannelHandler.self).whenSuccess { handler in - if self.notificationListeners[channel] != nil { - self.notificationListeners[channel]!.append((listenContext, notificationHandler)) - } - else { - self.notificationListeners[channel] = [(listenContext, notificationHandler)] - } - } - - listenContext.stopper = { [weak self, weak listenContext] in - // self is weak, since the connection can long be gone, when the listeners stop is - // triggered. listenContext must be weak to prevent a retain cycle - - self?.channel.eventLoop.execute { - guard - let self = self, // the connection is already gone - var listeners = self.notificationListeners[channel] // we don't have the listeners for this topic ¯\_(ツ)_/¯ - else { - return - } - - assert(listeners.filter { $0.0 === listenContext }.count <= 1, "Listeners can not appear twice in a channel!") - listeners.removeAll(where: { $0.0 === listenContext }) // just in case a listener shows up more than once in a release build, remove all, not just first - self.notificationListeners[channel] = listeners.isEmpty ? nil : listeners - } + listenContext.future.whenComplete { _ in + let task = HandlerTask.cancelListening(channel, id) + self.channel.write(task, promise: nil) } return listenContext } } -extension PostgresConnection: PSQLChannelHandlerNotificationDelegate { - func notificationReceived(_ notification: PostgresBackendMessage.NotificationResponse) { - self.eventLoop.assertInEventLoop() - - guard let listeners = self.notificationListeners[notification.channel] else { - return - } - - let postgresNotification = PostgresMessage.NotificationResponse( - backendPID: notification.backendPID, - channel: notification.channel, - payload: notification.payload) - - listeners.forEach { (listenContext, handler) in - handler(listenContext, postgresNotification) - } - } -} - enum CloseTarget { case preparedStatement(String) case portal(String) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index ba1e3c1f..761ba5f2 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -1091,11 +1091,12 @@ extension ConnectionStateMachine { .tooManyParameters, .invalidCommandTag, .connectionError, - .uncleanShutdown: + .uncleanShutdown, + .unlistenFailed: return true case .queryCancelled: return false - case .server: + case .server, .listenFailed: guard let sqlState = error.serverInfo?[.sqlState] else { // any error message that doesn't have a sql state field, is unexpected by default. return true diff --git a/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift new file mode 100644 index 00000000..c7f92428 --- /dev/null +++ b/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift @@ -0,0 +1,247 @@ +import NIOCore + +struct ListenStateMachine { + var channels: [String: ChannelState] + + init() { + self.channels = [:] + } + + enum StartListeningAction { + case none + case startListening(String) + case succeedListenStart(NotificationListener) + } + + mutating func startListening(_ new: NotificationListener) -> StartListeningAction { + return self.channels[new.channel, default: .init()].start(new) + } + + enum StartListeningSuccessAction { + case stopListening + case activateListeners(Dictionary.Values) + } + + mutating func startListeningSucceeded(channel: String) -> StartListeningSuccessAction { + return self.channels[channel]!.startListeningSucceeded() + } + + mutating func startListeningFailed(channel: String, error: Error) -> Dictionary.Values { + return self.channels[channel]!.startListeningFailed(error) + } + + enum StopListeningSuccessAction { + case startListening + case none + } + + mutating func stopListeningSucceeded(channel: String) -> StopListeningSuccessAction { + return self.channels[channel, default: .init()].stopListeningSucceeded() + } + + enum CancelAction { + case stopListening(String, cancelListener: NotificationListener) + case cancelListener(NotificationListener) + case none + } + + mutating func cancelNotificationListener(channel: String, id: Int) -> CancelAction { + return self.channels[channel, default: .init()].cancelListening(id: id) + } + + mutating func fail(_ error: Error) -> [NotificationListener] { + var result = [NotificationListener]() + while var (_, channel) = self.channels.popFirst() { + switch channel.fail(error) { + case .none: + continue + + case .failListeners(let listeners): + result.append(contentsOf: listeners) + } + } + return result + } + + enum ReceivedAction { + case none + case notify(Dictionary.Values) + } + + func notificationReceived(channel: String) -> ReceivedAction { + // TODO: Do we want to close the connection, if we receive a notification on a channel that we don't listen to? + // We can only change this with the next major release, as it would break current functionality. + return self.channels[channel]?.notificationReceived() ?? .none + } +} + +extension ListenStateMachine { + struct ChannelState { + enum State { + case initialized + case starting([Int: NotificationListener]) + case listening([Int: NotificationListener]) + case stopping([Int: NotificationListener]) + case failed(Error) + } + + private var state: State + + init() { + self.state = .initialized + } + + mutating func start(_ new: NotificationListener) -> StartListeningAction { + switch self.state { + case .initialized: + self.state = .starting([new.id: new]) + return .startListening(new.channel) + + case .starting(var listeners): + listeners[new.id] = new + self.state = .starting(listeners) + return .none + + case .listening(var listeners): + listeners[new.id] = new + self.state = .listening(listeners) + return .succeedListenStart(new) + + case .stopping(var listeners): + listeners[new.id] = new + self.state = .stopping(listeners) + return .none + + case .failed: + fatalError("Invalid state: \(self.state)") + } + } + + mutating func startListeningSucceeded() -> StartListeningSuccessAction { + switch self.state { + case .initialized, .listening, .stopping: + fatalError("Invalid state: \(self.state)") + + case .starting(let listeners): + if listeners.isEmpty { + self.state = .stopping(listeners) + return .stopListening + } else { + self.state = .listening(listeners) + return .activateListeners(listeners.values) + } + + case .failed: + fatalError("Invalid state: \(self.state)") + } + } + + mutating func startListeningFailed(_ error: Error) -> Dictionary.Values { + switch self.state { + case .initialized, .listening, .stopping: + fatalError("Invalid state: \(self.state)") + + case .starting(let listeners): + self.state = .initialized + return listeners.values + + case .failed: + fatalError("Invalid state: \(self.state)") + } + } + + mutating func stopListeningSucceeded() -> StopListeningSuccessAction { + switch self.state { + case .initialized, .listening, .starting: + fatalError("Invalid state: \(self.state)") + + case .stopping(let listeners): + if listeners.isEmpty { + self.state = .initialized + return .none + } else { + self.state = .starting(listeners) + return .startListening + } + + case .failed: + return .none + } + } + + mutating func cancelListening(id: Int) -> CancelAction { + switch self.state { + case .initialized: + fatalError("Invalid state: \(self.state)") + + case .starting(var listeners): + let removed = listeners.removeValue(forKey: id) + self.state = .starting(listeners) + if let removed = removed { + return .cancelListener(removed) + } + return .none + + case .listening(var listeners): + precondition(!listeners.isEmpty) + let maybeLast = listeners.removeValue(forKey: id) + if let last = maybeLast, listeners.isEmpty { + self.state = .stopping(listeners) + return .stopListening(last.channel, cancelListener: last) + } else { + self.state = .listening(listeners) + if let notLast = maybeLast { + return .cancelListener(notLast) + } + return .none + } + + case .stopping(var listeners): + let removed = listeners.removeValue(forKey: id) + self.state = .stopping(listeners) + if let removed = removed { + return .cancelListener(removed) + } + return .none + + case .failed: + return .none + } + } + + enum FailAction { + case failListeners(Dictionary.Values) + case none + } + + mutating func fail(_ error: Error) -> FailAction { + switch self.state { + case .initialized: + fatalError("Invalid state: \(self.state)") + + case .starting(let listeners), .listening(let listeners), .stopping(let listeners): + self.state = .failed(error) + return .failListeners(listeners.values) + + case .failed: + return .none + } + } + + func notificationReceived() -> ReceivedAction { + switch self.state { + case .initialized, .starting: + fatalError("Invalid state: \(self.state)") + + case .listening(let listeners): + return .notify(listeners.values) + + case .stopping: + return .none + + default: + preconditionFailure("TODO: Implemented") + } + } + } +} diff --git a/Sources/PostgresNIO/New/NotificationListener.swift b/Sources/PostgresNIO/New/NotificationListener.swift new file mode 100644 index 00000000..5f4bc3de --- /dev/null +++ b/Sources/PostgresNIO/New/NotificationListener.swift @@ -0,0 +1,157 @@ +import NIOCore + +// This object is @unchecked Sendable, since we syncronize state on the EL +final class NotificationListener: @unchecked Sendable { + let eventLoop: EventLoop + + let channel: String + let id: Int + + private var state: State + + enum State { + case streamInitialized(CheckedContinuation) + case streamListening(AsyncThrowingStream.Continuation) + + case closure(PostgresListenContext, (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void) + case done + } + + init( + channel: String, + id: Int, + eventLoop: EventLoop, + checkedContinuation: CheckedContinuation + ) { + self.channel = channel + self.id = id + self.eventLoop = eventLoop + self.state = .streamInitialized(checkedContinuation) + } + + init( + channel: String, + id: Int, + eventLoop: EventLoop, + context: PostgresListenContext, + closure: @Sendable @escaping (PostgresListenContext, PostgresMessage.NotificationResponse) -> Void + ) { + self.channel = channel + self.id = id + self.eventLoop = eventLoop + self.state = .closure(context, closure) + } + + func startListeningSucceeded(handler: PostgresChannelHandler) { + self.eventLoop.preconditionInEventLoop() + + switch self.state { + case .streamInitialized(let checkedContinuation): + let (stream, continuation) = AsyncThrowingStream.makeStream(of: PostgresNotification.self) + let eventLoop = self.eventLoop + let channel = self.channel + let listenerID = self.id + continuation.onTermination = { reason in + switch reason { + case .cancelled: + eventLoop.execute { + handler.cancelNotificationListener(channel: channel, id: listenerID) + } + + case .finished: + break + + @unknown default: + break + } + } + self.state = .streamListening(continuation) + + let notificationSequence = PostgresNotificationSequence(base: stream) + checkedContinuation.resume(returning: notificationSequence) + + case .streamListening, .done: + fatalError("Invalid state: \(self.state)") + + case .closure: + break // ignore + } + } + + func notificationReceived(_ backendMessage: PostgresBackendMessage.NotificationResponse) { + self.eventLoop.preconditionInEventLoop() + + switch self.state { + case .streamInitialized, .done: + fatalError("Invalid state: \(self.state)") + case .streamListening(let continuation): + continuation.yield(.init(payload: backendMessage.payload)) + + case .closure(let postgresListenContext, let closure): + let message = PostgresMessage.NotificationResponse( + backendPID: backendMessage.backendPID, + channel: backendMessage.channel, + payload: backendMessage.payload + ) + closure(postgresListenContext, message) + } + } + + func failed(_ error: Error) { + self.eventLoop.preconditionInEventLoop() + + switch self.state { + case .streamInitialized(let checkedContinuation): + self.state = .done + checkedContinuation.resume(throwing: error) + + case .streamListening(let continuation): + self.state = .done + continuation.finish(throwing: error) + + case .closure(let postgresListenContext, _): + self.state = .done + postgresListenContext.cancel() + + case .done: + break // ignore + } + } + + func cancelled() { + self.eventLoop.preconditionInEventLoop() + + switch self.state { + case .streamInitialized(let checkedContinuation): + self.state = .done + checkedContinuation.resume(throwing: PSQLError(code: .queryCancelled)) + + case .streamListening(let continuation): + self.state = .done + continuation.finish() + + case .closure(let postgresListenContext, _): + self.state = .done + postgresListenContext.cancel() + + case .done: + break // ignore + } + } +} + + +#if swift(<5.9) +// Async stream API backfill +extension AsyncThrowingStream { + static func makeStream( + of elementType: Element.Type = Element.self, + throwing failureType: Failure.Type = Failure.self, + bufferingPolicy limit: Continuation.BufferingPolicy = .unbounded + ) -> (stream: AsyncThrowingStream, continuation: AsyncThrowingStream.Continuation) where Failure == Error { + var continuation: AsyncThrowingStream.Continuation! + let stream = AsyncThrowingStream(bufferingPolicy: limit) { continuation = $0 } + return (stream: stream, continuation: continuation!) + } + } + #endif diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index df7dd7c1..a13d4209 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -22,6 +22,9 @@ public struct PSQLError: Error { case connectionClosed case connectionError case uncleanShutdown + + case listenFailed + case unlistenFailed } internal var base: Base @@ -46,6 +49,8 @@ public struct PSQLError: Error { public static let connectionClosed = Self(.connectionClosed) public static let connectionError = Self(.connectionError) public static let uncleanShutdown = Self.init(.uncleanShutdown) + public static let listenFailed = Self.init(.listenFailed) + public static let unlistenFailed = Self.init(.unlistenFailed) public var description: String { switch self.base { @@ -81,6 +86,10 @@ public struct PSQLError: Error { return "connectionError" case .uncleanShutdown: return "uncleanShutdown" + case .listenFailed: + return "listenFailed" + case .unlistenFailed: + return "unlistenFailed" } } } @@ -418,6 +427,12 @@ public struct PSQLError: Error { return error } + static func unlistenError(underlying: Error) -> PSQLError { + var error = PSQLError(code: .unlistenFailed) + error.underlying = underlying + return error + } + enum UnsupportedAuthScheme { case none case kerberosV5 diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index f9ca1232..26312c0c 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -1,17 +1,27 @@ import Logging import NIOCore +enum HandlerTask { + case extendedQuery(ExtendedQueryContext) + case preparedStatement(PrepareStatementContext) + case closeCommand(CloseCommandContext) + case startListening(NotificationListener) + case cancelListening(String, Int) +} + enum PSQLTask { case extendedQuery(ExtendedQueryContext) case preparedStatement(PrepareStatementContext) case closeCommand(CloseCommandContext) - + func failWithError(_ error: PSQLError) { switch self { case .extendedQuery(let extendedQueryContext): extendedQueryContext.promise.fail(error) + case .preparedStatement(let createPreparedStatementContext): createPreparedStatementContext.promise.fail(error) + case .closeCommand(let closeCommandContext): closeCommandContext.promise.fail(error) } diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 09feb521..4470e802 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -3,16 +3,13 @@ import NIOTLS import Crypto import Logging -protocol PSQLChannelHandlerNotificationDelegate: AnyObject { - func notificationReceived(_: PostgresBackendMessage.NotificationResponse) -} - final class PostgresChannelHandler: ChannelDuplexHandler { - typealias OutboundIn = PSQLTask + typealias OutboundIn = HandlerTask typealias InboundIn = ByteBuffer typealias OutboundOut = ByteBuffer private let logger: Logger + private let eventLoop: EventLoop private var state: ConnectionStateMachine /// A `ChannelHandlerContext` to be used for non channel related events. (for example: More rows needed). @@ -24,15 +21,18 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private var encoder: PostgresFrontendMessageEncoder! private let configuration: PostgresConnection.InternalConfiguration private let configureSSLCallback: ((Channel) throws -> Void)? - - /// this delegate should only be accessed on the connections `EventLoop` - weak var notificationDelegate: PSQLChannelHandlerNotificationDelegate? - - init(configuration: PostgresConnection.InternalConfiguration, - logger: Logger, - configureSSLCallback: ((Channel) throws -> Void)?) - { + + private var listenState: ListenStateMachine + + init( + configuration: PostgresConnection.InternalConfiguration, + eventLoop: EventLoop, + logger: Logger, + configureSSLCallback: ((Channel) throws -> Void)? + ) { self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData) + self.eventLoop = eventLoop + self.listenState = ListenStateMachine() self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger @@ -41,12 +41,16 @@ final class PostgresChannelHandler: ChannelDuplexHandler { #if DEBUG /// for testing purposes only - init(configuration: PostgresConnection.InternalConfiguration, - state: ConnectionStateMachine = .init(.initialized), - logger: Logger = .psqlNoOpLogger, - configureSSLCallback: ((Channel) throws -> Void)?) - { + init( + configuration: PostgresConnection.InternalConfiguration, + eventLoop: EventLoop, + state: ConnectionStateMachine = .init(.initialized), + logger: Logger = .psqlNoOpLogger, + configureSSLCallback: ((Channel) throws -> Void)? + ) { self.state = state + self.eventLoop = eventLoop + self.listenState = ListenStateMachine() self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger @@ -194,8 +198,46 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { - let task = self.unwrapOutboundIn(data) - let action = self.state.enqueue(task: task) + let handlerTask = self.unwrapOutboundIn(data) + let psqlTask: PSQLTask + + switch handlerTask { + case .closeCommand(let command): + psqlTask = .closeCommand(command) + case .extendedQuery(let query): + psqlTask = .extendedQuery(query) + case .preparedStatement(let statement): + psqlTask = .preparedStatement(statement) + + case .startListening(let listener): + switch self.listenState.startListening(listener) { + case .startListening(let channel): + psqlTask = self.makeStartListeningQuery(channel: channel, context: context) + + case .none: + return + + case .succeedListenStart(let listener): + listener.startListeningSucceeded(handler: self) + return + } + + case .cancelListening(let channel, let id): + switch self.listenState.cancelNotificationListener(channel: channel, id: id) { + case .none: + return + + case .stopListening(let channel, let listener): + psqlTask = self.makeUnlistenQuery(channel: channel, context: context) + listener.failed(CancellationError()) + + case .cancelListener(let listener): + listener.failed(CancellationError()) + return + } + } + + let action = self.state.enqueue(task: psqlTask) self.run(action, with: context) } @@ -223,9 +265,34 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } } + // MARK: Listening + + func cancelNotificationListener(channel: String, id: Int) { + self.eventLoop.preconditionInEventLoop() + + switch self.listenState.cancelNotificationListener(channel: channel, id: id) { + case .cancelListener(let listener): + listener.cancelled() + + case .stopListening(let channel, cancelListener: let listener): + listener.cancelled() + + guard let context = self.handlerContext else { + return + } + + let query = self.makeUnlistenQuery(channel: channel, context: context) + let action = self.state.enqueue(task: query) + self.run(action, with: context) + + case .none: + break + } + } + // MARK: Channel handler actions - func run(_ action: ConnectionStateMachine.ConnectionAction, with context: ChannelHandlerContext) { + private func run(_ action: ConnectionStateMachine.ConnectionAction, with context: ChannelHandlerContext) { self.logger.trace("Run action", metadata: [.connectionAction: "\(action)"]) switch action { @@ -333,16 +400,14 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.closeConnectionAndCleanup(cleanupContext, context: context) } case .forwardNotificationToListeners(let notification): - self.notificationDelegate?.notificationReceived(notification) + self.forwardNotificationToListeners(notification, context: context) } } // MARK: - Private Methods - private func connected(context: ChannelHandlerContext) { - let action = self.state.connected(tls: .init(self.configuration.tls)) - self.run(action, with: context) } @@ -362,8 +427,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private func sendPasswordMessage( mode: PasswordAuthencationMode, authContext: AuthContext, - context: ChannelHandlerContext) - { + context: ChannelHandlerContext + ) { switch mode { case .md5(let salt): let hash1 = (authContext.password ?? "") + authContext.username @@ -407,7 +472,6 @@ final class PostgresChannelHandler: ChannelDuplexHandler { context: ChannelHandlerContext ) { precondition(self.rowStream == nil, "Expected to not have an open stream at this point") - self.encoder.parse(preparedStatementName: statementName, query: query, parameters: []) self.encoder.describePreparedStatement(statementName) self.encoder.sync() @@ -485,11 +549,16 @@ final class PostgresChannelHandler: ChannelDuplexHandler { cleanup.tasks.forEach { task in task.failWithError(cleanup.error) } - - // 2. fire an error + + // 2. stop all listeners + for listener in self.listenState.fail(cleanup.error) { + listener.failed(cleanup.error) + } + + // 3. fire an error context.fireErrorCaught(cleanup.error) - // 3. close the connection or fire channel inactive + // 4. close the connection or fire channel inactive switch cleanup.action { case .close: context.close(mode: .all, promise: cleanup.closePromise) @@ -498,6 +567,105 @@ final class PostgresChannelHandler: ChannelDuplexHandler { context.fireChannelInactive() } } + + private func makeStartListeningQuery(channel: String, context: ChannelHandlerContext) -> PSQLTask { + let promise = context.eventLoop.makePromise(of: PSQLRowStream.self) + let query = ExtendedQueryContext( + query: PostgresQuery(unsafeSQL: "LISTEN \(channel);"), + logger: self.logger, + promise: promise + ) + promise.futureResult.whenComplete { result in + self.startListenCompleted(result, for: channel, context: context) + } + + return .extendedQuery(query) + } + + private func startListenCompleted(_ result: Result, for channel: String, context: ChannelHandlerContext) { + switch result { + case .success: + switch self.listenState.startListeningSucceeded(channel: channel) { + case .activateListeners(let listeners): + for list in listeners { + list.startListeningSucceeded(handler: self) + } + + case .stopListening: + let task = self.makeUnlistenQuery(channel: channel, context: context) + let action = self.state.enqueue(task: task) + self.run(action, with: context) + } + + case .failure(let error): + let finalError: PSQLError + if var psqlError = error as? PSQLError { + psqlError.code = .listenFailed + finalError = psqlError + } else { + var psqlError = PSQLError(code: .listenFailed) + psqlError.underlying = error + finalError = psqlError + } + let listeners = self.listenState.startListeningFailed(channel: channel, error: finalError) + for list in listeners { + list.failed(finalError) + } + } + } + + private func makeUnlistenQuery(channel: String, context: ChannelHandlerContext) -> PSQLTask { + let promise = context.eventLoop.makePromise(of: PSQLRowStream.self) + let query = ExtendedQueryContext( + query: PostgresQuery(unsafeSQL: "UNLISTEN \(channel);"), + logger: self.logger, + promise: promise + ) + promise.futureResult.whenComplete { result in + self.stopListenCompleted(result, for: channel, context: context) + } + + return .extendedQuery(query) + } + + private func stopListenCompleted( + _ result: Result, + for channel: String, + context: ChannelHandlerContext + ) { + switch result { + case .success: + switch self.listenState.stopListeningSucceeded(channel: channel) { + case .none: + break + + case .startListening: + let task = self.makeStartListeningQuery(channel: channel, context: context) + let action = self.state.enqueue(task: task) + self.run(action, with: context) + } + + case .failure(let error): + let action = self.state.errorHappened(.unlistenError(underlying: error)) + self.run(action, with: context) + } + } + + private func forwardNotificationToListeners( + _ notification: PostgresBackendMessage.NotificationResponse, + context: ChannelHandlerContext + ) { + switch self.listenState.notificationReceived(channel: notification.channel) { + case .none: + break + + case .notify(let listeners): + for listener in listeners { + listener.notificationReceived(notification) + } + } + } + } extension PostgresChannelHandler: PSQLRowsDataSource { @@ -578,16 +746,3 @@ extension ConnectionStateMachine.TLSConfiguration { } } } - -extension PostgresChannelHandler { - convenience init( - configuration: PostgresConnection.InternalConfiguration, - configureSSLCallback: ((Channel) throws -> Void)?) - { - self.init( - configuration: configuration, - logger: .psqlNoOpLogger, - configureSSLCallback: configureSSLCallback - ) - } -} diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessage.swift b/Sources/PostgresNIO/New/PostgresFrontendMessage.swift index 3963bd62..2a7ec9f1 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessage.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessage.swift @@ -14,7 +14,12 @@ enum PostgresFrontendMessage: Equatable { var preparedStatementName: String /// The number of parameter values that follow (possibly zero). This must match the number of parameters needed by the query. - var bind: PostgresBindings + var parameterFormats: [PostgresFormat] + + /// The number of parameter values that follow (possibly zero). This must match the number of parameters needed by the query. + var parameters: [ByteBuffer?] + + var resultColumnFormats: [PostgresFormat] } struct Cancel: Equatable { diff --git a/Sources/PostgresNIO/New/PostgresNotificationSequence.swift b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift new file mode 100644 index 00000000..735c01b0 --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift @@ -0,0 +1,22 @@ + +public struct PostgresNotification: Sendable { + public let payload: String +} + +public struct PostgresNotificationSequence: AsyncSequence, Sendable { + public typealias Element = PostgresNotification + + let base: AsyncThrowingStream + + public func makeAsyncIterator() -> AsyncIterator { + AsyncIterator(base: self.base.makeAsyncIterator()) + } + + public struct AsyncIterator: AsyncIteratorProtocol { + var base: AsyncThrowingStream.AsyncIterator + + public mutating func next() async throws -> Element? { + try await self.base.next() + } + } +} diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index ff9773f5..10970b26 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -5,7 +5,7 @@ extension PSQLError { switch self.code.base { case .queryCancelled: return self - case .server: + case .server, .listenFailed: guard let serverInfo = self.serverInfo else { return self } @@ -43,6 +43,8 @@ extension PSQLError { return PostgresError.connectionClosed case .connectionError: return self.underlying ?? self + case .unlistenFailed: + return self.underlying ?? self case .uncleanShutdown: return PostgresError.protocol("Unexpected connection close") } diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 7a45c5c0..f68ef1f3 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -224,6 +224,29 @@ final class AsyncPostgresConnectionTests: XCTestCase { } } + func testListenAndNotify() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + try await self.withTestConnection(on: eventLoop) { connection in + let stream = try await connection.listen("foo") + var iterator = stream.makeAsyncIterator() + + try await self.withTestConnection(on: eventLoop) { other in + try await other.query(#"NOTIFY foo, 'bar';"#, logger: .psqlTest) + + try await other.query(#"NOTIFY foo, 'foo';"#, logger: .psqlTest) + } + + let first = try await iterator.next() + XCTAssertEqual(first?.payload, "bar") + + let second = try await iterator.next() + XCTAssertEqual(second?.payload, "foo") + } + } + #if canImport(Network) func testSelect10kRowsNetworkFramework() async throws { let eventLoopGroup = NIOTSEventLoopGroup() diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index 342907ea..b9677000 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -125,17 +125,90 @@ extension PostgresFrontendMessage { static func decode(from buffer: inout ByteBuffer, for messageID: ID) throws -> PostgresFrontendMessage { switch messageID { case .bind: - preconditionFailure("TODO: Unimplemented") + guard let portalName = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + guard let preparedStatementName = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + guard let parameterFormatCount = buffer.readInteger(as: UInt16.self) else { + preconditionFailure("TODO: Unimplemented") + } + + let parameterFormats = (0.. ByteBuffer? in + let length = buffer.readInteger(as: UInt16.self) + switch length { + case .some(..<0): + return nil + case .some(0...): + return buffer.readSlice(length: Int(length!)) + default: + preconditionFailure("TODO: Unimplemented") + } + } + + guard let resultColumnFormatCount = buffer.readInteger(as: UInt16.self) else { + preconditionFailure("TODO: Unimplemented") + } + + let resultColumnFormats = (0.. (PostgresConnection, NIOAsyncTestingChannel) { + let eventLoop = NIOAsyncTestingEventLoop() + let channel = await NIOAsyncTestingChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), + ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), + ], loop: eventLoop) + try await channel.connect(to: .makeAddressResolvingHost("localhost", port: 5432)) + + let configuration = PostgresConnection.Configuration( + establishedChannel: channel, + username: "username", + password: "postgres", + database: "database" + ) + + async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: self.logger) + let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", replication: .false)))) + try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) + try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + let connection = try await connectionPromise + + self.addTeardownBlock { + try await connection.close() + } + + return (connection, channel) + } +} + +extension NIOAsyncTestingChannel { + + func waitForUnpreparedRequest() async throws -> UnpreparedRequest { + let parse = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let describe = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let bind = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let execute = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + + guard case .parse(let parse) = parse, + case .describe(let describe) = describe, + case .bind(let bind) = bind, + case .execute(let execute) = execute, + case .sync = sync + else { + fatalError() + } + + return UnpreparedRequest(parse: parse, describe: describe, bind: bind, execute: execute) + } +} + +struct UnpreparedRequest { + var parse: PostgresFrontendMessage.Parse + var describe: PostgresFrontendMessage.Describe + var bind: PostgresFrontendMessage.Bind + var execute: PostgresFrontendMessage.Execute +} From 5ffc8fc811f3e36317089031f80f15a4d31b5c44 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Mon, 7 Aug 2023 05:03:31 -0500 Subject: [PATCH 167/292] Upgrade CI (#382) --- .github/workflows/test.yml | 100 +++++++++++------- .../PostgresNIO/Docs.docc/images/article.svg | 1 + .../Docs.docc/images/vapor-postgres-logo.svg | 36 +++++++ .../PostgresNIO/Docs.docc/theme-settings.json | 46 ++++++++ 4 files changed, 143 insertions(+), 40 deletions(-) create mode 100644 Sources/PostgresNIO/Docs.docc/images/article.svg create mode 100644 Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg create mode 100644 Sources/PostgresNIO/Docs.docc/theme-settings.json diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 24821c77..2da05f81 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,49 +17,52 @@ jobs: strategy: fail-fast: false matrix: - container: + swift-image: - swift:5.6-focal - swift:5.7-jammy - swift:5.8-jammy - swiftlang/swift:nightly-5.9-jammy - swiftlang/swift:nightly-main-jammy - container: ${{ matrix.container }} + include: + - swift-image: swift:5.8-jammy + code-coverage: true + container: ${{ matrix.swift-image }} runs-on: ubuntu-latest steps: - - name: Note Swift version - if: ${{ contains(matrix.swiftver, 'nightly') }} - run: | - echo "SWIFT_PLATFORM=$(. /etc/os-release && echo "${ID}${VERSION_ID}")" >>"${GITHUB_ENV}" - echo "SWIFT_VERSION=$(cat /.swift_tag)" >>"${GITHUB_ENV}" - name: Display OS and Swift versions + shell: bash run: | - printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" && swift --version + [[ -z "${SWIFT_PLATFORM}" ]] && SWIFT_PLATFORM="$(. /etc/os-release && echo "${ID}${VERSION_ID}")" + [[ -z "${SWIFT_VERSION}" ]] && SWIFT_VERSION="$(cat /.swift_tag 2>/dev/null || true)" + printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" + swift --version - name: Check out package uses: actions/checkout@v3 - - name: Run unit tests with code coverage and Thread Sanitizer - run: swift test --filter=^PostgresNIOTests --sanitize=thread --enable-code-coverage - - name: Submit coverage report to Codecov.io - uses: vapor/swift-codecov-action@v0.2 - with: - cc_env_vars: 'SWIFT_VERSION,SWIFT_PLATFORM,RUNNER_OS,RUNNER_ARCH' - cc_fail_ci_if_error: false + - name: Run unit tests with Thread Sanitizer + env: + CODE_COVERAGE: ${{ matrix.code-coverage && '--enable-code-coverage' || '' }} + run: | + swift test --filter=^PostgresNIOTests --sanitize=thread ${CODE_COVERAGE} + - name: Submit code coverage + if: ${{ matrix.code-coverage }} + uses: vapor/swift-codecov-action@v0.2 linux-integration-and-dependencies: if: github.event_name == 'pull_request' strategy: fail-fast: false matrix: - dbimage: + postgres-image: - postgres:15 - postgres:13 - postgres:11 include: - - dbimage: postgres:15 - dbauth: scram-sha-256 - - dbimage: postgres:13 - dbauth: md5 - - dbimage: postgres:11 - dbauth: trust + - postgres-image: postgres:15 + postgres-auth: scram-sha-256 + - postgres-image: postgres:13 + postgres-auth: md5 + - postgres-image: postgres:11 + postgres-auth: trust container: image: swift:5.8-jammy volumes: [ 'pgrunshare:/var/run/postgresql' ] @@ -79,29 +82,31 @@ jobs: POSTGRES_HOSTNAME_A: 'psql-a' POSTGRES_HOSTNAME_B: 'psql-b' POSTGRES_SOCKET: '/var/run/postgresql/.s.PGSQL.5432' - POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} + POSTGRES_HOST_AUTH_METHOD: ${{ matrix.postgres-auth }} services: psql-a: - image: ${{ matrix.dbimage }} + image: ${{ matrix.postgres-image }} volumes: [ 'pgrunshare:/var/run/postgresql' ] env: POSTGRES_USER: 'test_username' POSTGRES_DB: 'test_database' POSTGRES_PASSWORD: 'test_password' - POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} - POSTGRES_INITDB_ARGS: --auth-host=${{ matrix.dbauth }} + POSTGRES_HOST_AUTH_METHOD: ${{ matrix.postgres-auth }} + POSTGRES_INITDB_ARGS: --auth-host=${{ matrix.postgres-auth }} psql-b: - image: ${{ matrix.dbimage }} + image: ${{ matrix.postgres-image }} volumes: [ 'pgrunshare:/var/run/postgresql' ] env: POSTGRES_USER: 'test_username' POSTGRES_DB: 'test_database' POSTGRES_PASSWORD: 'test_password' - POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} - POSTGRES_INITDB_ARGS: --auth-host=${{ matrix.dbauth }} + POSTGRES_HOST_AUTH_METHOD: ${{ matrix.postgres-auth }} + POSTGRES_INITDB_ARGS: --auth-host=${{ matrix.postgres-auth }} steps: - name: Display OS and Swift versions run: | + [[ -z "${SWIFT_PLATFORM}" ]] && SWIFT_PLATFORM="$(. /etc/os-release && echo "${ID}${VERSION_ID}")" + [[ -z "${SWIFT_VERSION}" ]] && SWIFT_VERSION="$(cat /.swift_tag 2>/dev/null || true)" printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" && swift --version - name: Check out package uses: actions/checkout@v3 @@ -128,33 +133,34 @@ jobs: strategy: fail-fast: false matrix: - dbimage: + postgres-formula: # Only test one version on macOS, let Linux do the rest - postgresql@14 - dbauth: + postgres-auth: # Only test one auth method on macOS, Linux tests will cover the others - scram-sha-256 - xcode: - - latest-stable + xcode-version: + - '~14.3' + - '15.0-beta' runs-on: macos-13 env: POSTGRES_HOSTNAME: 127.0.0.1 POSTGRES_USER: 'test_username' POSTGRES_PASSWORD: 'test_password' POSTGRES_DB: 'postgres' - POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }} + POSTGRES_AUTH_METHOD: ${{ matrix.postgres-auth }} POSTGRES_SOCKET: '/tmp/.s.PGSQL.5432' - POSTGRES_VERSION: ${{ matrix.dbimage }} + POSTGRES_FORMULA: ${{ matrix.postgres-formula }} steps: - name: Select latest available Xcode uses: maxim-lobanov/setup-xcode@v1 with: - xcode-version: ${{ matrix.xcode }} + xcode-version: ${{ matrix.xcode-version }} - name: Install Postgres, setup DB and auth, and wait for server start run: | - export PATH="$(brew --prefix)/opt/${POSTGRES_VERSION}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test - (brew unlink postgresql || true) && brew install "${POSTGRES_VERSION}" && brew link --force "${POSTGRES_VERSION}" - initdb --locale=C --auth-host "${POSTGRES_HOST_AUTH_METHOD}" -U "${POSTGRES_USER}" --pwfile=<(echo "${POSTGRES_PASSWORD}") + export PATH="$(brew --prefix)/opt/${POSTGRES_FORMULA}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test + (brew unlink postgresql || true) && brew install "${POSTGRES_FORMULA}" && brew link --force "${POSTGRES_FORMULA}" + initdb --locale=C --auth-host "${POSTGRES_AUTH_METHOD}" -U "${POSTGRES_USER}" --pwfile=<(echo "${POSTGRES_PASSWORD}") pg_ctl start --wait timeout-minutes: 2 - name: Checkout code @@ -165,7 +171,7 @@ jobs: api-breakage: if: github.event_name == 'pull_request' runs-on: ubuntu-latest - container: swift:5.8-jammy + container: swift:jammy steps: - name: Checkout uses: actions/checkout@v3 @@ -177,3 +183,17 @@ jobs: - name: API breaking changes run: swift package diagnose-api-breaking-changes origin/main + gh-codeql: + runs-on: ubuntu-latest + permissions: { security-events: write } + steps: + - name: Check out code + uses: actions/checkout@v3 + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: swift + - name: Perform build + run: swift build + - name: Run CodeQL analyze + uses: github/codeql-action/analyze@v2 diff --git a/Sources/PostgresNIO/Docs.docc/images/article.svg b/Sources/PostgresNIO/Docs.docc/images/article.svg new file mode 100644 index 00000000..3dc6a66c --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/images/article.svg @@ -0,0 +1 @@ + diff --git a/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg b/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg new file mode 100644 index 00000000..e1c1223b --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg @@ -0,0 +1,36 @@ + + + + + + + + + + + + + + + diff --git a/Sources/PostgresNIO/Docs.docc/theme-settings.json b/Sources/PostgresNIO/Docs.docc/theme-settings.json new file mode 100644 index 00000000..c6ce054e --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/theme-settings.json @@ -0,0 +1,46 @@ +{ + "theme": { + "aside": { + "border-radius": "6px", + "border-style": "double", + "border-width": "3px" + }, + "border-radius": "0", + "button": { + "border-radius": "16px", + "border-width": "1px", + "border-style": "solid" + }, + "code": { + "border-radius": "16px", + "border-width": "1px", + "border-style": "solid" + }, + "color": { + "fill": { + "dark": "rgb(20, 20, 22)", + "light": "rgb(255, 255, 255)" + }, + "psql-blue": "#336791", + "documentation-intro-fill": "radial-gradient(circle at top, var(--color-documentation-intro-accent) 30%, #1f1d1f 100%)", + "documentation-intro-accent": "var(--color-psql-blue)", + "documentation-intro-accent-outer": { + "dark": "rgb(255, 255, 255)", + "light": "rgb(51, 51, 51)" + }, + "documentation-intro-accent-inner": { + "dark": "rgb(51, 51, 51)", + "light": "rgb(255, 255, 255)" + } + }, + "icons": { + "technology": "/postgresnio/images/vapor-postgres-logo.svg", + "article": "/postgresnio/images/article.svg" + } + }, + "features": { + "quickNavigation": { + "enable": true + } + } +} From 329ce83ee4d45c063b908f3f66efb49c930ac5f6 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 7 Aug 2023 12:16:52 +0200 Subject: [PATCH 168/292] Cleanup PostgresBackendMessage (#384) --- .../ConnectionStateMachine.swift | 13 +---- .../New/Messages/Authentication.swift | 37 ++------------ .../New/Messages/BackendKeyData.swift | 2 +- .../PostgresNIO/New/Messages/DataRow.swift | 2 +- .../New/Messages/ErrorResponse.swift | 4 +- .../New/Messages/NotificationResponse.swift | 2 +- .../New/Messages/ParameterDescription.swift | 2 +- .../New/Messages/ParameterStatus.swift | 2 +- .../New/Messages/ReadyForQuery.swift | 34 ++----------- .../New/Messages/RowDescription.swift | 4 +- .../New/PostgresBackendMessage.swift | 2 +- .../New/PostgresChannelHandler.swift | 8 +-- .../AuthenticationStateMachineTests.swift | 16 +++--- .../ConnectionStateMachineTests.swift | 4 +- .../PSQLBackendMessage+Equatable.swift | 49 ------------------- .../PSQLBackendMessageEncoder.swift | 8 +-- .../New/Messages/AuthenticationTests.swift | 22 +++++---- .../New/PSQLBackendMessageTests.swift | 5 +- .../New/PostgresChannelHandlerTests.swift | 20 +++----- 19 files changed, 57 insertions(+), 179 deletions(-) delete mode 100644 Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 761ba5f2..93312c86 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -1258,18 +1258,7 @@ struct AuthContext: Equatable, CustomDebugStringConvertible { enum PasswordAuthencationMode: Equatable { case cleartext - case md5(salt: (UInt8, UInt8, UInt8, UInt8)) - - static func ==(lhs: Self, rhs: Self) -> Bool { - switch (lhs, rhs) { - case (.cleartext, .cleartext): - return true - case (.md5(let lhs), .md5(let rhs)): - return lhs == rhs - default: - return false - } - } + case md5(salt: UInt32) } extension ConnectionStateMachine.State: CustomDebugStringConvertible { diff --git a/Sources/PostgresNIO/New/Messages/Authentication.swift b/Sources/PostgresNIO/New/Messages/Authentication.swift index bd0d2e57..eff62e91 100644 --- a/Sources/PostgresNIO/New/Messages/Authentication.swift +++ b/Sources/PostgresNIO/New/Messages/Authentication.swift @@ -2,10 +2,10 @@ import NIOCore extension PostgresBackendMessage { - enum Authentication: PayloadDecodable { + enum Authentication: PayloadDecodable, Hashable { case ok case kerberosV5 - case md5(salt: (UInt8, UInt8, UInt8, UInt8)) + case md5(salt: UInt32) case plaintext case scmCredential case gss @@ -26,7 +26,7 @@ extension PostgresBackendMessage { case 3: return .plaintext case 5: - guard let salt = buffer.readMultipleIntegers(endianness: .big, as: (UInt8, UInt8, UInt8, UInt8).self) else { + guard let salt = buffer.readInteger(as: UInt32.self) else { throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(4, actual: buffer.readableBytes) } return .md5(salt: salt) @@ -61,37 +61,6 @@ extension PostgresBackendMessage { } } -extension PostgresBackendMessage.Authentication: Equatable { - static func ==(lhs: Self, rhs: Self) -> Bool { - switch (lhs, rhs) { - case (.ok, .ok): - return true - case (.kerberosV5, .kerberosV5): - return true - case (.md5(let lhs), .md5(let rhs)): - return lhs == rhs - case (.plaintext, .plaintext): - return true - case (.scmCredential, .scmCredential): - return true - case (.gss, .gss): - return true - case (.sspi, .sspi): - return true - case (.gssContinue(let lhs), .gssContinue(let rhs)): - return lhs == rhs - case (.sasl(let lhs), .sasl(let rhs)): - return lhs == rhs - case (.saslContinue(let lhs), .saslContinue(let rhs)): - return lhs == rhs - case (.saslFinal(let lhs), .saslFinal(let rhs)): - return lhs == rhs - default: - return false - } - } -} - extension PostgresBackendMessage.Authentication: CustomDebugStringConvertible { var debugDescription: String { switch self { diff --git a/Sources/PostgresNIO/New/Messages/BackendKeyData.swift b/Sources/PostgresNIO/New/Messages/BackendKeyData.swift index 498c5110..31a676d2 100644 --- a/Sources/PostgresNIO/New/Messages/BackendKeyData.swift +++ b/Sources/PostgresNIO/New/Messages/BackendKeyData.swift @@ -2,7 +2,7 @@ import NIOCore extension PostgresBackendMessage { - struct BackendKeyData: PayloadDecodable, Equatable { + struct BackendKeyData: PayloadDecodable, Hashable { let processID: Int32 let secretKey: Int32 diff --git a/Sources/PostgresNIO/New/Messages/DataRow.swift b/Sources/PostgresNIO/New/Messages/DataRow.swift index b181e600..491e10dc 100644 --- a/Sources/PostgresNIO/New/Messages/DataRow.swift +++ b/Sources/PostgresNIO/New/Messages/DataRow.swift @@ -9,7 +9,7 @@ import NIOCore /// Not putting `DataRow` in ``PSQLBackendMessage`` is our way to trick /// the Swift compiler @usableFromInline -struct DataRow: Sendable, PostgresBackendMessage.PayloadDecodable, Equatable { +struct DataRow: Sendable, PostgresBackendMessage.PayloadDecodable, Hashable { @usableFromInline var columnCount: Int16 @usableFromInline diff --git a/Sources/PostgresNIO/New/Messages/ErrorResponse.swift b/Sources/PostgresNIO/New/Messages/ErrorResponse.swift index 818c1ebf..d0bb6044 100644 --- a/Sources/PostgresNIO/New/Messages/ErrorResponse.swift +++ b/Sources/PostgresNIO/New/Messages/ErrorResponse.swift @@ -80,7 +80,7 @@ extension PostgresBackendMessage { case routine = 0x52 /// R } - struct ErrorResponse: PSQLMessageNotice, PayloadDecodable, Equatable { + struct ErrorResponse: PSQLMessageNotice, PayloadDecodable, Hashable { let fields: [PostgresBackendMessage.Field: String] init(fields: [PostgresBackendMessage.Field: String]) { @@ -88,7 +88,7 @@ extension PostgresBackendMessage { } } - struct NoticeResponse: PSQLMessageNotice, PayloadDecodable, Equatable { + struct NoticeResponse: PSQLMessageNotice, PayloadDecodable, Hashable { let fields: [PostgresBackendMessage.Field: String] init(fields: [PostgresBackendMessage.Field: String]) { diff --git a/Sources/PostgresNIO/New/Messages/NotificationResponse.swift b/Sources/PostgresNIO/New/Messages/NotificationResponse.swift index 5cd9422e..01b9ab4a 100644 --- a/Sources/PostgresNIO/New/Messages/NotificationResponse.swift +++ b/Sources/PostgresNIO/New/Messages/NotificationResponse.swift @@ -2,7 +2,7 @@ import NIOCore extension PostgresBackendMessage { - struct NotificationResponse: PayloadDecodable, Equatable { + struct NotificationResponse: PayloadDecodable, Hashable { let backendPID: Int32 let channel: String let payload: String diff --git a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift index 1ccc91e5..4d12b1b6 100644 --- a/Sources/PostgresNIO/New/Messages/ParameterDescription.swift +++ b/Sources/PostgresNIO/New/Messages/ParameterDescription.swift @@ -2,7 +2,7 @@ import NIOCore extension PostgresBackendMessage { - struct ParameterDescription: PayloadDecodable, Equatable { + struct ParameterDescription: PayloadDecodable, Hashable { /// Specifies the object ID of the parameter data type. var dataTypes: [PostgresDataType] diff --git a/Sources/PostgresNIO/New/Messages/ParameterStatus.swift b/Sources/PostgresNIO/New/Messages/ParameterStatus.swift index 4ffcbe12..52d07e01 100644 --- a/Sources/PostgresNIO/New/Messages/ParameterStatus.swift +++ b/Sources/PostgresNIO/New/Messages/ParameterStatus.swift @@ -2,7 +2,7 @@ import NIOCore extension PostgresBackendMessage { - struct ParameterStatus: PayloadDecodable, Equatable { + struct ParameterStatus: PayloadDecodable, Hashable { /// The name of the run-time parameter being reported. var parameter: String diff --git a/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift b/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift index a300f714..41af1b60 100644 --- a/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift +++ b/Sources/PostgresNIO/New/Messages/ReadyForQuery.swift @@ -1,37 +1,11 @@ import NIOCore extension PostgresBackendMessage { - enum TransactionState: PayloadDecodable, RawRepresentable { - typealias RawValue = UInt8 - - case idle - case inTransaction - case inFailedTransaction - - init?(rawValue: UInt8) { - switch rawValue { - case UInt8(ascii: "I"): - self = .idle - case UInt8(ascii: "T"): - self = .inTransaction - case UInt8(ascii: "E"): - self = .inFailedTransaction - default: - return nil - } - } + enum TransactionState: UInt8, PayloadDecodable, Hashable { + case idle = 73 // ascii: I + case inTransaction = 84 // ascii: T + case inFailedTransaction = 69 // ascii: E - var rawValue: Self.RawValue { - switch self { - case .idle: - return UInt8(ascii: "I") - case .inTransaction: - return UInt8(ascii: "T") - case .inFailedTransaction: - return UInt8(ascii: "E") - } - } - static func decode(from buffer: inout ByteBuffer) throws -> Self { let value = try buffer.throwingReadInteger(as: UInt8.self) guard let state = Self.init(rawValue: value) else { diff --git a/Sources/PostgresNIO/New/Messages/RowDescription.swift b/Sources/PostgresNIO/New/Messages/RowDescription.swift index 66c71215..766d06e9 100644 --- a/Sources/PostgresNIO/New/Messages/RowDescription.swift +++ b/Sources/PostgresNIO/New/Messages/RowDescription.swift @@ -9,13 +9,13 @@ import NIOCore /// Not putting `DataRow` in ``PSQLBackendMessage`` is our way to trick /// the Swift compiler. @usableFromInline -struct RowDescription: PostgresBackendMessage.PayloadDecodable, Sendable, Equatable { +struct RowDescription: PostgresBackendMessage.PayloadDecodable, Sendable, Hashable { /// Specifies the object ID of the parameter data type. @usableFromInline var columns: [Column] @usableFromInline - struct Column: Equatable, Sendable { + struct Column: Hashable, Sendable { /// The field name. @usableFromInline var name: String diff --git a/Sources/PostgresNIO/New/PostgresBackendMessage.swift b/Sources/PostgresNIO/New/PostgresBackendMessage.swift index ecccd1e9..71c3cacd 100644 --- a/Sources/PostgresNIO/New/PostgresBackendMessage.swift +++ b/Sources/PostgresNIO/New/PostgresBackendMessage.swift @@ -20,7 +20,7 @@ protocol PSQLMessagePayloadDecodable { /// /// All messages are defined in the official Postgres Documentation in the section /// [Frontend/Backend Protocol – Message Formats](https://www.postgresql.org/docs/13/protocol-message-formats.html) -enum PostgresBackendMessage { +enum PostgresBackendMessage: Hashable { typealias PayloadDecodable = PSQLMessagePayloadDecodable diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 4470e802..32c35927 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -437,10 +437,10 @@ final class PostgresChannelHandler: ChannelDuplexHandler { var hash2 = [UInt8]() hash2.reserveCapacity(pwdhash.count + 4) hash2.append(contentsOf: pwdhash) - hash2.append(salt.0) - hash2.append(salt.1) - hash2.append(salt.2) - hash2.append(salt.3) + var saltNetworkOrder = salt.bigEndian + withUnsafeBytes(of: &saltNetworkOrder) { ptr in + hash2.append(contentsOf: ptr) + } let hash = Insecure.MD5.hash(data: hash2).md5PrefixHexdigest() self.encoder.password(hash.utf8) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift index 87478e63..b06b69ab 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift @@ -19,8 +19,8 @@ class AuthenticationStateMachineTests: XCTestCase { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) - let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) - + let salt: UInt32 = 0x00_01_02_03 + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait) @@ -30,8 +30,8 @@ class AuthenticationStateMachineTests: XCTestCase { let authContext = AuthContext(username: "test", password: nil, database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) - let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) - + let salt: UInt32 = 0x00_01_02_03 + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .authMechanismRequiresPassword, closePromise: nil))) @@ -49,8 +49,8 @@ class AuthenticationStateMachineTests: XCTestCase { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) - let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) - + let salt: UInt32 = 0x00_01_02_03 + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) let fields: [PostgresBackendMessage.Field: String] = [ @@ -107,12 +107,12 @@ class AuthenticationStateMachineTests: XCTestCase { } func testUnexpectedMessagesAfterPasswordSent() { - let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) + let salt: UInt32 = 0x00_01_02_03 var buffer = ByteBuffer() buffer.writeBytes([0, 1, 2, 3, 4, 5, 6, 7, 8]) let unexpected: [PostgresBackendMessage.Authentication] = [ .kerberosV5, - .md5(salt: (0, 1, 2, 3)), + .md5(salt: salt), .plaintext, .scmCredential, .gss, diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index 289665fb..d5d4ecb1 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -23,7 +23,7 @@ class ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.sslHandlerAdded(), .wait) XCTAssertEqual(state.sslEstablished(), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) - let salt: (UInt8, UInt8, UInt8, UInt8) = (0,1,2,3) + let salt: UInt32 = 0x00_01_02_03 XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) } @@ -154,7 +154,7 @@ class ConnectionStateMachineTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let authContext = AuthContext(username: "test", password: "abc123", database: "test") - let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3) + let salt: UInt32 = 0x00_01_02_03 let queryPromise = eventLoopGroup.next().makePromise(of: PSQLRowStream.self) diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift deleted file mode 100644 index c459ffeb..00000000 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift +++ /dev/null @@ -1,49 +0,0 @@ -@testable import PostgresNIO - -extension PostgresBackendMessage: Equatable { - - public static func ==(lhs: Self, rhs: Self) -> Bool { - switch (lhs, rhs) { - case (.authentication(let lhs), .authentication(let rhs)): - return lhs == rhs - case (.backendKeyData(let lhs), .backendKeyData(let rhs)): - return lhs == rhs - case (.bindComplete, bindComplete): - return true - case (.closeComplete, closeComplete): - return true - case (.commandComplete(let lhs), commandComplete(let rhs)): - return lhs == rhs - case (.dataRow(let lhs), dataRow(let rhs)): - return lhs == rhs - case (.emptyQueryResponse, emptyQueryResponse): - return true - case (.error(let lhs), error(let rhs)): - return lhs == rhs - case (.noData, noData): - return true - case (.notice(let lhs), notice(let rhs)): - return lhs == rhs - case (.notification(let lhs), .notification(let rhs)): - return lhs == rhs - case (.parameterDescription(let lhs), parameterDescription(let rhs)): - return lhs == rhs - case (.parameterStatus(let lhs), parameterStatus(let rhs)): - return lhs == rhs - case (.parseComplete, parseComplete): - return true - case (.portalSuspended, portalSuspended): - return true - case (.readyForQuery(let lhs), readyForQuery(let rhs)): - return lhs == rhs - case (.rowDescription(let lhs), rowDescription(let rhs)): - return lhs == rhs - case (.sslSupported, sslSupported): - return true - case (.sslUnsupported, sslUnsupported): - return true - default: - return false - } - } -} diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift index eea7dec3..e51c14f9 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift @@ -9,7 +9,7 @@ struct PSQLBackendMessageEncoder: MessageToByteEncoder { /// - parameters: /// - data: The data to encode into a `ByteBuffer`. /// - out: The `ByteBuffer` into which we want to encode. - func encode(data message: PostgresBackendMessage, out buffer: inout ByteBuffer) throws { + func encode(data message: PostgresBackendMessage, out buffer: inout ByteBuffer) { switch message { case .authentication(let authentication): self.encode(messageID: message.id, payload: authentication, into: &buffer) @@ -144,11 +144,7 @@ extension PostgresBackendMessage.Authentication: PSQLMessagePayloadEncodable { buffer.writeInteger(Int32(3)) case .md5(salt: let salt): - buffer.writeInteger(Int32(5)) - buffer.writeInteger(salt.0) - buffer.writeInteger(salt.1) - buffer.writeInteger(salt.2) - buffer.writeInteger(salt.3) + buffer.writeMultipleIntegers(Int32(5), salt) case .scmCredential: buffer.writeInteger(Int32(6)) diff --git a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift index 31a21a91..06e39aae 100644 --- a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift @@ -11,35 +11,37 @@ class AuthenticationTests: XCTestCase { let encoder = PSQLBackendMessageEncoder() // add ok - XCTAssertNoThrow(try encoder.encode(data: .authentication(.ok), out: &buffer)) + encoder.encode(data: .authentication(.ok), out: &buffer) expected.append(.authentication(.ok)) // add kerberos - XCTAssertNoThrow(try encoder.encode(data: .authentication(.kerberosV5), out: &buffer)) + encoder.encode(data: .authentication(.kerberosV5), out: &buffer) expected.append(.authentication(.kerberosV5)) // add plaintext - XCTAssertNoThrow(try encoder.encode(data: .authentication(.plaintext), out: &buffer)) + encoder.encode(data: .authentication(.plaintext), out: &buffer) expected.append(.authentication(.plaintext)) // add md5 - XCTAssertNoThrow(try encoder.encode(data: .authentication(.md5(salt: (1, 2, 3, 4))), out: &buffer)) - expected.append(.authentication(.md5(salt: (1, 2, 3, 4)))) - + let salt: UInt32 = 0x01_02_03_04 + encoder.encode(data: .authentication(.md5(salt: salt)), out: &buffer) + expected.append(.authentication(.md5(salt: salt))) + // add scm credential - XCTAssertNoThrow(try encoder.encode(data: .authentication(.scmCredential), out: &buffer)) + encoder.encode(data: .authentication(.scmCredential), out: &buffer) expected.append(.authentication(.scmCredential)) // add gss - XCTAssertNoThrow(try encoder.encode(data: .authentication(.gss), out: &buffer)) + encoder.encode(data: .authentication(.gss), out: &buffer) expected.append(.authentication(.gss)) // add sspi - XCTAssertNoThrow(try encoder.encode(data: .authentication(.sspi), out: &buffer)) + encoder.encode(data: .authentication(.sspi), out: &buffer) expected.append(.authentication(.sspi)) XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, expected)], - decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) } + )) } } diff --git a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift index 10e8503a..195c7fb4 100644 --- a/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLBackendMessageTests.swift @@ -256,11 +256,12 @@ class PSQLBackendMessageTests: XCTestCase { } func testDebugDescription() { + let salt: UInt32 = 0x00_01_02_03 XCTAssertEqual("\(PostgresBackendMessage.authentication(.ok))", ".authentication(.ok)") XCTAssertEqual("\(PostgresBackendMessage.authentication(.kerberosV5))", ".authentication(.kerberosV5)") - XCTAssertEqual("\(PostgresBackendMessage.authentication(.md5(salt: (0, 1, 2, 3))))", - ".authentication(.md5(salt: (0, 1, 2, 3)))") + XCTAssertEqual("\(PostgresBackendMessage.authentication(.md5(salt: salt)))", + ".authentication(.md5(salt: \(salt)))") XCTAssertEqual("\(PostgresBackendMessage.authentication(.plaintext))", ".authentication(.plaintext)") XCTAssertEqual("\(PostgresBackendMessage.authentication(.scmCredential))", diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index 4484d6a4..5388e8b5 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -152,19 +152,19 @@ class PostgresChannelHandlerTests: XCTestCase { let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop, state: state, configureSSLCallback: nil) let embedded = EmbeddedChannel(handlers: [ ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), - ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), handler ], loop: self.eventLoop) embedded.triggerUserOutboundEvent(PSQLOutgoingEvent.authenticate(authContext), promise: nil) XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .startup(.versionThree(parameters: authContext.toStartupParameters()))) + let salt: UInt32 = 0x00_01_02_03 + + let encoder = PSQLBackendMessageEncoder() + var byteBuffer = ByteBuffer() + encoder.encode(data: .authentication(.md5(salt: salt)), out: &byteBuffer) + XCTAssertNoThrow(try embedded.writeInbound(byteBuffer)) - XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.authentication(.md5(salt: (0,1,2,3))))) - - var message: PostgresFrontendMessage? - XCTAssertNoThrow(message = try embedded.readOutbound(as: PostgresFrontendMessage.self)) - - XCTAssertEqual(message, .password(.init(value: "md522d085ed8dc3377968dc1c1a40519a2a"))) + XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .password(.init(value: "md522d085ed8dc3377968dc1c1a40519a2a"))) } func testRunAuthenticateCleartext() { @@ -187,11 +187,7 @@ class PostgresChannelHandlerTests: XCTestCase { XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .startup(.versionThree(parameters: authContext.toStartupParameters()))) XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.authentication(.plaintext))) - - var message: PostgresFrontendMessage? - XCTAssertNoThrow(message = try embedded.readOutbound(as: PostgresFrontendMessage.self)) - - XCTAssertEqual(message, .password(.init(value: password))) + XCTAssertEqual(try embedded.readOutbound(as: PostgresFrontendMessage.self), .password(.init(value: password))) } func testHandlerThatSendsMultipleWrongMessages() { From 0a1c54e38961a8989d37bb8ee75da38c3f7232aa Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 7 Aug 2023 16:18:42 +0200 Subject: [PATCH 169/292] PostgresBackendMessage.ID should be backed by UInt8 directly (#386) --- .../New/PostgresBackendMessage.swift | 160 +++--------------- .../New/PostgresBackendMessageDecoder.swift | 12 +- 2 files changed, 31 insertions(+), 141 deletions(-) diff --git a/Sources/PostgresNIO/New/PostgresBackendMessage.swift b/Sources/PostgresNIO/New/PostgresBackendMessage.swift index 71c3cacd..792beec3 100644 --- a/Sources/PostgresNIO/New/PostgresBackendMessage.swift +++ b/Sources/PostgresNIO/New/PostgresBackendMessage.swift @@ -46,141 +46,31 @@ enum PostgresBackendMessage: Hashable { } extension PostgresBackendMessage { - enum ID: RawRepresentable, Equatable { - typealias RawValue = UInt8 - - case authentication - case backendKeyData - case bindComplete - case closeComplete - case commandComplete - case copyData - case copyDone - case copyInResponse - case copyOutResponse - case copyBothResponse - case dataRow - case emptyQueryResponse - case error - case functionCallResponse - case negotiateProtocolVersion - case noData - case noticeResponse - case notificationResponse - case parameterDescription - case parameterStatus - case parseComplete - case portalSuspended - case readyForQuery - case rowDescription - - init?(rawValue: UInt8) { - switch rawValue { - case UInt8(ascii: "R"): - self = .authentication - case UInt8(ascii: "K"): - self = .backendKeyData - case UInt8(ascii: "2"): - self = .bindComplete - case UInt8(ascii: "3"): - self = .closeComplete - case UInt8(ascii: "C"): - self = .commandComplete - case UInt8(ascii: "d"): - self = .copyData - case UInt8(ascii: "c"): - self = .copyDone - case UInt8(ascii: "G"): - self = .copyInResponse - case UInt8(ascii: "H"): - self = .copyOutResponse - case UInt8(ascii: "W"): - self = .copyBothResponse - case UInt8(ascii: "D"): - self = .dataRow - case UInt8(ascii: "I"): - self = .emptyQueryResponse - case UInt8(ascii: "E"): - self = .error - case UInt8(ascii: "V"): - self = .functionCallResponse - case UInt8(ascii: "v"): - self = .negotiateProtocolVersion - case UInt8(ascii: "n"): - self = .noData - case UInt8(ascii: "N"): - self = .noticeResponse - case UInt8(ascii: "A"): - self = .notificationResponse - case UInt8(ascii: "t"): - self = .parameterDescription - case UInt8(ascii: "S"): - self = .parameterStatus - case UInt8(ascii: "1"): - self = .parseComplete - case UInt8(ascii: "s"): - self = .portalSuspended - case UInt8(ascii: "Z"): - self = .readyForQuery - case UInt8(ascii: "T"): - self = .rowDescription - default: - return nil - } - } - - var rawValue: UInt8 { - switch self { - case .authentication: - return UInt8(ascii: "R") - case .backendKeyData: - return UInt8(ascii: "K") - case .bindComplete: - return UInt8(ascii: "2") - case .closeComplete: - return UInt8(ascii: "3") - case .commandComplete: - return UInt8(ascii: "C") - case .copyData: - return UInt8(ascii: "d") - case .copyDone: - return UInt8(ascii: "c") - case .copyInResponse: - return UInt8(ascii: "G") - case .copyOutResponse: - return UInt8(ascii: "H") - case .copyBothResponse: - return UInt8(ascii: "W") - case .dataRow: - return UInt8(ascii: "D") - case .emptyQueryResponse: - return UInt8(ascii: "I") - case .error: - return UInt8(ascii: "E") - case .functionCallResponse: - return UInt8(ascii: "V") - case .negotiateProtocolVersion: - return UInt8(ascii: "v") - case .noData: - return UInt8(ascii: "n") - case .noticeResponse: - return UInt8(ascii: "N") - case .notificationResponse: - return UInt8(ascii: "A") - case .parameterDescription: - return UInt8(ascii: "t") - case .parameterStatus: - return UInt8(ascii: "S") - case .parseComplete: - return UInt8(ascii: "1") - case .portalSuspended: - return UInt8(ascii: "s") - case .readyForQuery: - return UInt8(ascii: "Z") - case .rowDescription: - return UInt8(ascii: "T") - } - } + enum ID: UInt8, Hashable { + case authentication = 82 // ascii: R + case backendKeyData = 75 // ascii: K + case bindComplete = 50 // ascii: 2 + case closeComplete = 51 // ascii: 3 + case commandComplete = 67 // ascii: C + case copyData = 100 // ascii: d + case copyDone = 99 // ascii: c + case copyInResponse = 71 // ascii: G + case copyOutResponse = 72 // ascii: H + case copyBothResponse = 87 // ascii: W + case dataRow = 68 // ascii: D + case emptyQueryResponse = 73 // ascii: I + case error = 69 // ascii: E + case functionCallResponse = 86 // ascii: V + case negotiateProtocolVersion = 118 // ascii: v + case noData = 110 // ascii: n + case noticeResponse = 78 // ascii: N + case notificationResponse = 65 // ascii: A + case parameterDescription = 116 // ascii: t + case parameterStatus = 83 // ascii: S + case parseComplete = 49 // ascii: 1 + case portalSuspended = 115 // ascii: s + case readyForQuery = 90 // ascii: Z + case rowDescription = 84 // ascii: T } } diff --git a/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift b/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift index ee7e1b84..6f6be7ec 100644 --- a/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift +++ b/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift @@ -107,8 +107,8 @@ struct PostgresMessageDecodingError: Error { static func withPartialError( _ partialError: PSQLPartialDecodingError, messageID: UInt8, - messageBytes: ByteBuffer) -> Self - { + messageBytes: ByteBuffer + ) -> Self { var byteBuffer = messageBytes let data = byteBuffer.readData(length: byteBuffer.readableBytes)! @@ -124,8 +124,8 @@ struct PostgresMessageDecodingError: Error { messageID: UInt8, messageBytes: ByteBuffer, file: String = #fileID, - line: Int = #line) -> Self - { + line: Int = #line + ) -> Self { var byteBuffer = messageBytes let data = byteBuffer.readData(length: byteBuffer.readableBytes)! @@ -153,8 +153,8 @@ struct PSQLPartialDecodingError: Error { value: Target.RawValue, asType: Target.Type, file: String = #fileID, - line: Int = #line) -> Self - { + line: Int = #line + ) -> Self { return PSQLPartialDecodingError( description: "Can not represent '\(value)' with type '\(asType)'.", file: file, line: line) From 220eb501f336ec3e22605e9c16dc7d8ce4251e6b Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Mon, 7 Aug 2023 16:18:57 -0500 Subject: [PATCH 170/292] Typo fix: Storiage -> Storage (#387) --- Sources/PostgresNIO/New/PSQLError.swift | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index a13d4209..5d9e534c 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -96,7 +96,7 @@ public struct PSQLError: Error { private var backing: Backing - private mutating func copyBackingStoriageIfNecessary() { + private mutating func copyBackingStorageIfNecessary() { if !isKnownUniquelyReferenced(&self.backing) { self.backing = self.backing.copy() } @@ -106,7 +106,7 @@ public struct PSQLError: Error { public internal(set) var code: Code { get { self.backing.code } set { - self.copyBackingStoriageIfNecessary() + self.copyBackingStorageIfNecessary() self.backing.code = newValue } } @@ -115,7 +115,7 @@ public struct PSQLError: Error { public internal(set) var serverInfo: ServerInfo? { get { self.backing.serverInfo } set { - self.copyBackingStoriageIfNecessary() + self.copyBackingStorageIfNecessary() self.backing.serverInfo = newValue } } @@ -124,7 +124,7 @@ public struct PSQLError: Error { public internal(set) var underlying: Error? { get { self.backing.underlying } set { - self.copyBackingStoriageIfNecessary() + self.copyBackingStorageIfNecessary() self.backing.underlying = newValue } } @@ -133,7 +133,7 @@ public struct PSQLError: Error { public internal(set) var file: String? { get { self.backing.file } set { - self.copyBackingStoriageIfNecessary() + self.copyBackingStorageIfNecessary() self.backing.file = newValue } } @@ -142,7 +142,7 @@ public struct PSQLError: Error { public internal(set) var line: Int? { get { self.backing.line } set { - self.copyBackingStoriageIfNecessary() + self.copyBackingStorageIfNecessary() self.backing.line = newValue } } @@ -151,7 +151,7 @@ public struct PSQLError: Error { public internal(set) var query: PostgresQuery? { get { self.backing.query } set { - self.copyBackingStoriageIfNecessary() + self.copyBackingStorageIfNecessary() self.backing.query = newValue } } @@ -161,7 +161,7 @@ public struct PSQLError: Error { var backendMessage: PostgresBackendMessage? { get { self.backing.backendMessage } set { - self.copyBackingStoriageIfNecessary() + self.copyBackingStorageIfNecessary() self.backing.backendMessage = newValue } } @@ -171,7 +171,7 @@ public struct PSQLError: Error { var unsupportedAuthScheme: UnsupportedAuthScheme? { get { self.backing.unsupportedAuthScheme } set { - self.copyBackingStoriageIfNecessary() + self.copyBackingStorageIfNecessary() self.backing.unsupportedAuthScheme = newValue } } @@ -181,7 +181,7 @@ public struct PSQLError: Error { var invalidCommandTag: String? { get { self.backing.invalidCommandTag } set { - self.copyBackingStoriageIfNecessary() + self.copyBackingStorageIfNecessary() self.backing.invalidCommandTag = newValue } } From c5737e8a54c59da09bb1e699ab1c4e4b4fd99844 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Mon, 7 Aug 2023 23:35:42 -0500 Subject: [PATCH 171/292] [no ci] Fix missing docs attribute --- Sources/PostgresNIO/Docs.docc/index.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Sources/PostgresNIO/Docs.docc/index.md b/Sources/PostgresNIO/Docs.docc/index.md index b4dc7e30..ebe27cd0 100644 --- a/Sources/PostgresNIO/Docs.docc/index.md +++ b/Sources/PostgresNIO/Docs.docc/index.md @@ -1,5 +1,9 @@ # ``PostgresNIO`` +@Metadata { + @TitleHeading(Package) +} + 🐘 Non-blocking, event-driven Swift client for PostgreSQL built on SwiftNIO. ## Overview From b6597f7c419a70a31b08b0dcafafe052c58b1d86 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 9 Aug 2023 23:07:39 +0200 Subject: [PATCH 172/292] Remove PrepareStatementStateMachine (#391) Preparing a statement is a substep of running an extended query. For this reason we should reuse the `ExtendedQueryStateMachine` as much as we can. This patch removes the `PrepareStatementStateMachine` and uses the `ExtendedQueryStateMachine`. As a result of this we can simplify our code in lots of other places. --- .../Connection/PostgresConnection.swift | 7 +- .../ConnectionStateMachine.swift | 165 +++---------- .../ExtendedQueryStateMachine.swift | 157 +++++++++---- .../PrepareStatementStateMachine.swift | 147 ------------ Sources/PostgresNIO/New/PSQLRowStream.swift | 36 ++- Sources/PostgresNIO/New/PSQLTask.swift | 71 +++--- .../New/PostgresChannelHandler.swift | 67 +++--- .../ConnectionStateMachineTests.swift | 6 +- .../ExtendedQueryStateMachineTests.swift | 16 +- .../PrepareStatementStateMachineTests.swift | 47 ++-- .../ConnectionAction+TestUtils.swift | 17 +- .../New/PSQLRowStreamTests.swift | 140 ++++------- .../New/PostgresRowSequenceTests.swift | 220 ++++++++---------- 13 files changed, 421 insertions(+), 675 deletions(-) delete mode 100644 Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index d6420a6e..6f849bdd 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -230,13 +230,14 @@ public final class PostgresConnection: @unchecked Sendable { func prepareStatement(_ query: String, with name: String, logger: Logger) -> EventLoopFuture { let promise = self.channel.eventLoop.makePromise(of: RowDescription?.self) - let context = PrepareStatementContext( + let context = ExtendedQueryContext( name: name, query: query, logger: logger, - promise: promise) + promise: promise + ) - self.channel.write(HandlerTask.preparedStatement(context), promise: nil) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) return promise.futureResult.map { rowDescription in PSQLPreparedStatement(name: name, query: query, connection: self, rowDescription: rowDescription) } diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 93312c86..0f3e96c9 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -31,7 +31,6 @@ struct ConnectionStateMachine { case readyForQuery(ConnectionContext) case extendedQuery(ExtendedQueryStateMachine, ConnectionContext) - case prepareStatement(PrepareStatementStateMachine, ConnectionContext) case closeCommand(CloseStateMachine, ConnectionContext) case error(PSQLError) @@ -89,10 +88,9 @@ struct ConnectionStateMachine { // --- general actions case sendParseDescribeBindExecuteSync(PostgresQuery) case sendBindExecuteSync(PSQLExecuteStatement) - case failQuery(ExtendedQueryContext, with: PSQLError, cleanupContext: CleanUpContext?) - case succeedQuery(ExtendedQueryContext, columns: [RowDescription.Column]) - case succeedQueryNoRowsComming(ExtendedQueryContext, commandTag: String) - + case failQuery(EventLoopPromise, with: PSQLError, cleanupContext: CleanUpContext?) + case succeedQuery(EventLoopPromise, with: QueryResult) + // --- streaming actions // actions if query has requested next row but we are waiting for backend case forwardRows([DataRow]) @@ -101,9 +99,9 @@ struct ConnectionStateMachine { // Prepare statement actions case sendParseDescribeSync(name: String, query: String) - case succeedPreparedStatementCreation(PrepareStatementContext, with: RowDescription?) - case failPreparedStatementCreation(PrepareStatementContext, with: PSQLError, cleanupContext: CleanUpContext?) - + case succeedPreparedStatementCreation(EventLoopPromise, with: RowDescription?) + case failPreparedStatementCreation(EventLoopPromise, with: PSQLError, cleanupContext: CleanUpContext?) + // Close actions case sendCloseSync(CloseTarget) case succeedClose(CloseCommandContext) @@ -159,7 +157,6 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, - .prepareStatement, .closeCommand, .error, .closing, @@ -214,7 +211,6 @@ struct ConnectionStateMachine { .authenticating, .readyForQuery, .extendedQuery, - .prepareStatement, .closeCommand: return self.errorHappened(.uncleanShutdown) @@ -245,7 +241,6 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, - .prepareStatement, .closeCommand, .error, .closing, @@ -274,7 +269,6 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, - .prepareStatement, .closeCommand, .error, .closing, @@ -296,7 +290,6 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, - .prepareStatement, .closeCommand, .error, .closing, @@ -322,7 +315,6 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, - .prepareStatement, .closeCommand, .error, .closing, @@ -391,12 +383,6 @@ struct ConnectionStateMachine { machine.state = .extendedQuery(query, connectionContext) return .wait } - case .prepareStatement(let prepareState, var connectionContext): - return self.avoidingStateMachineCoW { machine in - connectionContext.parameters[status.parameter] = status.value - machine.state = .prepareStatement(prepareState, connectionContext) - return .wait - } case .closeCommand(let closeState, var connectionContext): return self.avoidingStateMachineCoW { machine in connectionContext.parameters[status.parameter] = status.value @@ -450,15 +436,6 @@ struct ConnectionStateMachine { machine.state = .extendedQuery(extendedQueryState, connectionContext) return machine.modify(with: action) } - case .prepareStatement(var preparedState, let connectionContext): - if preparedState.isComplete { - return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.error(errorMessage))) - } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = preparedState.errorReceived(errorMessage) - machine.state = .prepareStatement(preparedState, connectionContext) - return machine.modify(with: action) - } case .closing: // If the state machine is in state `.closing`, the connection shutdown was initiated // by the client. This means a `TERMINATE` message has already been sent and the @@ -493,13 +470,6 @@ struct ConnectionStateMachine { let action = queryState.errorHappened(error) return self.modify(with: action) } - case .prepareStatement(var prepareState, _): - if prepareState.isComplete { - return self.closeConnectionAndCleanup(error) - } else { - let action = prepareState.errorHappened(error) - return self.modify(with: action) - } case .closeCommand(var closeState, _): if closeState.isComplete { return self.closeConnectionAndCleanup(error) @@ -567,16 +537,6 @@ struct ConnectionStateMachine { self.state = .readyForQuery(connectionContext) return self.executeNextQueryFromQueue() - case .prepareStatement(let preparedStateMachine, var connectionContext): - guard preparedStateMachine.isComplete else { - return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.readyForQuery(transactionState))) - } - - connectionContext.transactionState = transactionState - - self.state = .readyForQuery(connectionContext) - return self.executeNextQueryFromQueue() - case .closeCommand(let closeStateMachine, var connectionContext): guard closeStateMachine.isComplete else { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.readyForQuery(transactionState))) @@ -597,9 +557,13 @@ struct ConnectionStateMachine { if case .quiescing = self.quiescingState { switch task { case .extendedQuery(let queryContext): - return .failQuery(queryContext, with: .connectionQuiescing, cleanupContext: nil) - case .preparedStatement(let prepareContext): - return .failPreparedStatementCreation(prepareContext, with: .connectionQuiescing, cleanupContext: nil) + switch queryContext.query { + case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): + return .failQuery(eventLoopPromise, with: .connectionQuiescing, cleanupContext: nil) + case .prepareStatement(_, _, let eventLoopPromise): + return .failPreparedStatementCreation(eventLoopPromise, with: .connectionQuiescing, cleanupContext: nil) + } + case .closeCommand(let closeContext): return .failClose(closeContext, with: .connectionQuiescing, cleanupContext: nil) } @@ -611,9 +575,12 @@ struct ConnectionStateMachine { case .closed: switch task { case .extendedQuery(let queryContext): - return .failQuery(queryContext, with: .connectionClosed, cleanupContext: nil) - case .preparedStatement(let prepareContext): - return .failPreparedStatementCreation(prepareContext, with: .connectionClosed, cleanupContext: nil) + switch queryContext.query { + case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): + return .failQuery(eventLoopPromise, with: .connectionClosed, cleanupContext: nil) + case .prepareStatement(_, _, let eventLoopPromise): + return .failPreparedStatementCreation(eventLoopPromise, with: .connectionClosed, cleanupContext: nil) + } case .closeCommand(let closeContext): return .failClose(closeContext, with: .connectionClosed, cleanupContext: nil) } @@ -633,7 +600,6 @@ struct ConnectionStateMachine { .authenticating, .authenticated, .readyForQuery, - .prepareStatement, .closeCommand, .error, .closing, @@ -676,12 +642,6 @@ struct ConnectionStateMachine { machine.state = .extendedQuery(extendedQuery, connectionContext) return machine.modify(with: action) } - case .prepareStatement(var preparedStatement, let connectionContext): - return self.avoidingStateMachineCoW { machine in - let action = preparedStatement.readEventCaught() - machine.state = .prepareStatement(preparedStatement, connectionContext) - return machine.modify(with: action) - } case .closeCommand(var closeState, let connectionContext): return self.avoidingStateMachineCoW { machine in let action = closeState.readEventCaught() @@ -709,12 +669,6 @@ struct ConnectionStateMachine { machine.state = .extendedQuery(queryState, connectionContext) return machine.modify(with: action) } - case .prepareStatement(var preparedState, let connectionContext) where !preparedState.isComplete: - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = preparedState.parseCompletedReceived() - machine.state = .prepareStatement(preparedState, connectionContext) - return machine.modify(with: action) - } default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.parseComplete)) } @@ -740,12 +694,6 @@ struct ConnectionStateMachine { machine.state = .extendedQuery(queryState, connectionContext) return machine.modify(with: action) } - case .prepareStatement(var preparedState, let connectionContext) where !preparedState.isComplete: - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = preparedState.parameterDescriptionReceived(description) - machine.state = .prepareStatement(preparedState, connectionContext) - return machine.modify(with: action) - } default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.parameterDescription(description))) } @@ -759,12 +707,6 @@ struct ConnectionStateMachine { machine.state = .extendedQuery(queryState, connectionContext) return machine.modify(with: action) } - case .prepareStatement(var preparedState, let connectionContext) where !preparedState.isComplete: - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = preparedState.rowDescriptionReceived(description) - machine.state = .prepareStatement(preparedState, connectionContext) - return machine.modify(with: action) - } default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.rowDescription(description))) } @@ -778,12 +720,6 @@ struct ConnectionStateMachine { machine.state = .extendedQuery(queryState, connectionContext) return machine.modify(with: action) } - case .prepareStatement(var preparedState, let connectionContext) where !preparedState.isComplete: - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = preparedState.noDataReceived() - machine.state = .prepareStatement(preparedState, connectionContext) - return machine.modify(with: action) - } default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.noData)) } @@ -909,6 +845,7 @@ struct ConnectionStateMachine { preconditionFailure("Expect to fail auth") } return .closeConnectionAndCleanup(cleanupContext) + case .extendedQuery(var queryStateMachine, _): let cleanupContext = self.setErrorAndCreateCleanupContext(error) @@ -921,9 +858,10 @@ struct ConnectionStateMachine { switch queryStateMachine.errorHappened(error) { case .sendParseDescribeBindExecuteSync, + .sendParseDescribeSync, .sendBindExecuteSync, .succeedQuery, - .succeedQueryNoRowsComming, + .succeedPreparedStatementCreation, .forwardRows, .forwardStreamComplete, .wait, @@ -935,26 +873,10 @@ struct ConnectionStateMachine { return .failQuery(queryContext, with: error, cleanupContext: cleanupContext) case .forwardStreamError(let error, let read): return .forwardStreamError(error, read: read, cleanupContext: cleanupContext) + case .failPreparedStatementCreation(let promise, with: let error): + return .failPreparedStatementCreation(promise, with: error, cleanupContext: cleanupContext) } - case .prepareStatement(var prepareStateMachine, _): - let cleanupContext = self.setErrorAndCreateCleanupContext(error) - - if prepareStateMachine.isComplete { - // in case the prepare state machine is complete all necessary actions have already - // been forwarded to the consumer. We can close and cleanup without caring about the - // substate machine. - return .closeConnectionAndCleanup(cleanupContext) - } - - switch prepareStateMachine.errorHappened(error) { - case .sendParseDescribeSync, - .succeedPreparedStatementCreation, - .read, - .wait: - preconditionFailure("Invalid state: \(self.state)") - case .failPreparedStatementCreation(let preparedStatementContext, with: let error): - return .failPreparedStatementCreation(preparedStatementContext, with: error, cleanupContext: cleanupContext) - } + case .closeCommand(var closeStateMachine, _): let cleanupContext = self.setErrorAndCreateCleanupContext(error) @@ -974,6 +896,7 @@ struct ConnectionStateMachine { case .failClose(let closeCommandContext, with: let error): return .failClose(closeCommandContext, with: error, cleanupContext: cleanupContext) } + case .error, .closing, .closed: // We might run into this case because of reentrancy. For example: After we received an // backend unexpected message, that we read of the wire, we bring this connection into @@ -1018,13 +941,6 @@ struct ConnectionStateMachine { machine.state = .extendedQuery(extendedQuery, connectionContext) return machine.modify(with: action) } - case .preparedStatement(let prepareContext): - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - var prepareStatement = PrepareStatementStateMachine(createContext: prepareContext) - let action = prepareStatement.start() - machine.state = .prepareStatement(prepareStatement, connectionContext) - return machine.modify(with: action) - } case .closeCommand(let closeContext): return self.avoidingStateMachineCoW { machine -> ConnectionAction in var closeStateMachine = CloseStateMachine(closeContext: closeContext) @@ -1153,10 +1069,8 @@ extension ConnectionStateMachine { case .failQuery(let requestContext, with: let error): let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) return .failQuery(requestContext, with: error, cleanupContext: cleanupContext) - case .succeedQuery(let requestContext, columns: let columns): - return .succeedQuery(requestContext, columns: columns) - case .succeedQueryNoRowsComming(let requestContext, let commandTag): - return .succeedQueryNoRowsComming(requestContext, commandTag: commandTag) + case .succeedQuery(let requestContext, with: let result): + return .succeedQuery(requestContext, with: result) case .forwardRows(let buffer): return .forwardRows(buffer) case .forwardStreamComplete(let buffer, let commandTag): @@ -1174,24 +1088,13 @@ extension ConnectionStateMachine { return .read case .wait: return .wait - } - } -} - -extension ConnectionStateMachine { - mutating func modify(with action: PrepareStatementStateMachine.Action) -> ConnectionStateMachine.ConnectionAction { - switch action { - case .sendParseDescribeSync(let name, let query): + case .sendParseDescribeSync(name: let name, query: let query): return .sendParseDescribeSync(name: name, query: query) - case .succeedPreparedStatementCreation(let prepareContext, with: let rowDescription): - return .succeedPreparedStatementCreation(prepareContext, with: rowDescription) - case .failPreparedStatementCreation(let prepareContext, with: let error): + case .succeedPreparedStatementCreation(let promise, with: let rowDescription): + return .succeedPreparedStatementCreation(promise, with: rowDescription) + case .failPreparedStatementCreation(let promise, with: let error): let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) - return .failPreparedStatementCreation(prepareContext, with: error, cleanupContext: cleanupContext) - case .read: - return .read - case .wait: - return .wait + return .failPreparedStatementCreation(promise, with: error, cleanupContext: cleanupContext) } } } @@ -1282,8 +1185,6 @@ extension ConnectionStateMachine.State: CustomDebugStringConvertible { return ".readyForQuery(connectionContext: \(String(reflecting: connectionContext)))" case .extendedQuery(let subStateMachine, let connectionContext): return ".extendedQuery(\(String(reflecting: subStateMachine)), connectionContext: \(String(reflecting: connectionContext)))" - case .prepareStatement(let subStateMachine, let connectionContext): - return ".prepareStatement(\(String(reflecting: subStateMachine)), connectionContext: \(String(reflecting: connectionContext)))" case .closeCommand(let subStateMachine, let connectionContext): return ".closeCommand(\(String(reflecting: subStateMachine)), connectionContext: \(String(reflecting: connectionContext)))" case .error(let error): diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 8b46fd0b..3a84031b 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -4,7 +4,7 @@ struct ExtendedQueryStateMachine { private enum State { case initialized(ExtendedQueryContext) - case parseDescribeBindExecuteSyncSent(ExtendedQueryContext) + case messagesSent(ExtendedQueryContext) case parseCompleteReceived(ExtendedQueryContext) case parameterDescriptionReceived(ExtendedQueryContext) @@ -26,15 +26,18 @@ struct ExtendedQueryStateMachine { enum Action { case sendParseDescribeBindExecuteSync(PostgresQuery) + case sendParseDescribeSync(name: String, query: String) case sendBindExecuteSync(PSQLExecuteStatement) // --- general actions - case failQuery(ExtendedQueryContext, with: PSQLError) - case succeedQuery(ExtendedQueryContext, columns: [RowDescription.Column]) - case succeedQueryNoRowsComming(ExtendedQueryContext, commandTag: String) + case failQuery(EventLoopPromise, with: PSQLError) + case succeedQuery(EventLoopPromise, with: QueryResult) case evaluateErrorAtConnectionLevel(PSQLError) + case succeedPreparedStatementCreation(EventLoopPromise, with: RowDescription?) + case failPreparedStatementCreation(EventLoopPromise, with: PSQLError) + // --- streaming actions // actions if query has requested next row but we are waiting for backend case forwardRows([DataRow]) @@ -59,13 +62,13 @@ struct ExtendedQueryStateMachine { } switch queryContext.query { - case .unnamed(let query): + case .unnamed(let query, _): return self.avoidingStateMachineCoW { state -> Action in - state = .parseDescribeBindExecuteSyncSent(queryContext) + state = .messagesSent(queryContext) return .sendParseDescribeBindExecuteSync(query) } - case .preparedStatement(let prepared): + case .executeStatement(let prepared, _): return self.avoidingStateMachineCoW { state -> Action in switch prepared.rowDescription { case .some(let rowDescription): @@ -75,6 +78,12 @@ struct ExtendedQueryStateMachine { } return .sendBindExecuteSync(prepared) } + + case .prepareStatement(let name, let query, _): + return self.avoidingStateMachineCoW { state -> Action in + state = .messagesSent(queryContext) + return .sendParseDescribeSync(name: name, query: query) + } } } @@ -83,7 +92,7 @@ struct ExtendedQueryStateMachine { case .initialized: preconditionFailure("Start must be called immediatly after the query was created") - case .parseDescribeBindExecuteSyncSent(let queryContext), + case .messagesSent(let queryContext), .parseCompleteReceived(let queryContext), .parameterDescriptionReceived(let queryContext), .rowDescriptionReceived(let queryContext, _), @@ -94,7 +103,13 @@ struct ExtendedQueryStateMachine { } self.isCancelled = true - return .failQuery(queryContext, with: .queryCancelled) + switch queryContext.query { + case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): + return .failQuery(eventLoopPromise, with: .queryCancelled) + + case .prepareStatement(_, _, let eventLoopPromise): + return .failPreparedStatementCreation(eventLoopPromise, with: .queryCancelled) + } case .streaming(let columns, var streamStateMachine): precondition(!self.isCancelled) @@ -117,7 +132,7 @@ struct ExtendedQueryStateMachine { } mutating func parseCompletedReceived() -> Action { - guard case .parseDescribeBindExecuteSyncSent(let queryContext) = self.state else { + guard case .messagesSent(let queryContext) = self.state else { return self.setAndFireError(.unexpectedBackendMessage(.parseComplete)) } @@ -143,9 +158,18 @@ struct ExtendedQueryStateMachine { return self.setAndFireError(.unexpectedBackendMessage(.noData)) } - return self.avoidingStateMachineCoW { state -> Action in - state = .noDataMessageReceived(queryContext) - return .wait + switch queryContext.query { + case .unnamed, .executeStatement: + return self.avoidingStateMachineCoW { state -> Action in + state = .noDataMessageReceived(queryContext) + return .wait + } + + case .prepareStatement(_, _, let promise): + return self.avoidingStateMachineCoW { state -> Action in + state = .noDataMessageReceived(queryContext) + return .succeedPreparedStatementCreation(promise, with: nil) + } } } @@ -153,40 +177,56 @@ struct ExtendedQueryStateMachine { guard case .parameterDescriptionReceived(let queryContext) = self.state else { return self.setAndFireError(.unexpectedBackendMessage(.rowDescription(rowDescription))) } - - return self.avoidingStateMachineCoW { state -> Action in - // In Postgres extended queries we receive the `rowDescription` before we send the - // `Bind` message. Well actually it's vice versa, but this is only true since we do - // pipelining during a query. - // - // In the actual protocol description we receive a rowDescription before the Bind - - // In Postgres extended queries we always request the response rows to be returned in - // `.binary` format. - let columns = rowDescription.columns.map { column -> RowDescription.Column in - var column = column - column.format = .binary - return column - } + + // In Postgres extended queries we receive the `rowDescription` before we send the + // `Bind` message. Well actually it's vice versa, but this is only true since we do + // pipelining during a query. + // + // In the actual protocol description we receive a rowDescription before the Bind + + // In Postgres extended queries we always request the response rows to be returned in + // `.binary` format. + let columns = rowDescription.columns.map { column -> RowDescription.Column in + var column = column + column.format = .binary + return column + } + + self.avoidingStateMachineCoW { state in state = .rowDescriptionReceived(queryContext, columns) + } + + switch queryContext.query { + case .unnamed, .executeStatement: return .wait + + case .prepareStatement(_, _, let eventLoopPromise): + return .succeedPreparedStatementCreation(eventLoopPromise, with: rowDescription) } } mutating func bindCompleteReceived() -> Action { switch self.state { - case .rowDescriptionReceived(let context, let columns): - return self.avoidingStateMachineCoW { state -> Action in - state = .streaming(columns, .init()) - return .succeedQuery(context, columns: columns) + case .rowDescriptionReceived(let queryContext, let columns): + switch queryContext.query { + case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): + return self.avoidingStateMachineCoW { state -> Action in + state = .streaming(columns, .init()) + let result = QueryResult(value: .rowDescription(columns), logger: queryContext.logger) + return .succeedQuery(eventLoopPromise, with: result) + } + + case .prepareStatement: + return .evaluateErrorAtConnectionLevel(.unexpectedBackendMessage(.bindComplete)) } + case .noDataMessageReceived(let queryContext): return self.avoidingStateMachineCoW { state -> Action in state = .bindCompleteReceived(queryContext) return .wait } case .initialized, - .parseDescribeBindExecuteSyncSent, + .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .bindCompleteReceived, @@ -224,7 +264,7 @@ struct ExtendedQueryStateMachine { return .wait case .initialized, - .parseDescribeBindExecuteSyncSent, + .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, @@ -241,9 +281,16 @@ struct ExtendedQueryStateMachine { mutating func commandCompletedReceived(_ commandTag: String) -> Action { switch self.state { case .bindCompleteReceived(let context): - return self.avoidingStateMachineCoW { state -> Action in - state = .commandComplete(commandTag: commandTag) - return .succeedQueryNoRowsComming(context, commandTag: commandTag) + switch context.query { + case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): + return self.avoidingStateMachineCoW { state -> Action in + state = .commandComplete(commandTag: commandTag) + let result = QueryResult(value: .noRows(commandTag), logger: context.logger) + return .succeedQuery(eventLoopPromise, with: result) + } + + case .prepareStatement: + preconditionFailure("Invalid state: \(self.state)") } case .streaming(_, var demandStateMachine): @@ -258,7 +305,7 @@ struct ExtendedQueryStateMachine { return .wait case .initialized, - .parseDescribeBindExecuteSyncSent, + .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, @@ -280,7 +327,7 @@ struct ExtendedQueryStateMachine { switch self.state { case .initialized: return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) - case .parseDescribeBindExecuteSyncSent, + case .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .bindCompleteReceived: @@ -331,7 +378,7 @@ struct ExtendedQueryStateMachine { return .wait case .initialized, - .parseDescribeBindExecuteSyncSent, + .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, @@ -354,7 +401,7 @@ struct ExtendedQueryStateMachine { .commandComplete, .drain, .error, - .parseDescribeBindExecuteSyncSent, + .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, @@ -381,7 +428,7 @@ struct ExtendedQueryStateMachine { mutating func readEventCaught() -> Action { switch self.state { - case .parseDescribeBindExecuteSyncSent, + case .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, @@ -417,7 +464,7 @@ struct ExtendedQueryStateMachine { private mutating func setAndFireError(_ error: PSQLError) -> Action { switch self.state { case .initialized(let context), - .parseDescribeBindExecuteSyncSent(let context), + .messagesSent(let context), .parseCompleteReceived(let context), .parameterDescriptionReceived(let context), .rowDescriptionReceived(let context, _), @@ -427,7 +474,12 @@ struct ExtendedQueryStateMachine { if self.isCancelled { return .evaluateErrorAtConnectionLevel(error) } else { - return .failQuery(context, with: error) + switch context.query { + case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): + return .failQuery(eventLoopPromise, with: error) + case .prepareStatement(_, _, let eventLoopPromise): + return .failPreparedStatementCreation(eventLoopPromise, with: error) + } } case .drain: @@ -455,11 +507,22 @@ struct ExtendedQueryStateMachine { var isComplete: Bool { switch self.state { - case .commandComplete, - .error: + case .commandComplete, .error: return true - default: + + case .noDataMessageReceived(let context), .rowDescriptionReceived(let context, _): + switch context.query { + case .prepareStatement: + return true + case .unnamed, .executeStatement: + return false + } + + case .initialized, .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .bindCompleteReceived, .streaming, .drain: return false + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") } } } diff --git a/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift deleted file mode 100644 index 5b65fc90..00000000 --- a/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift +++ /dev/null @@ -1,147 +0,0 @@ - -struct PrepareStatementStateMachine { - - enum State { - case initialized(PrepareStatementContext) - case parseDescribeSent(PrepareStatementContext) - - case parseCompleteReceived(PrepareStatementContext) - case parameterDescriptionReceived(PrepareStatementContext) - case rowDescriptionReceived - case noDataMessageReceived - - case error(PSQLError) - } - - enum Action { - case sendParseDescribeSync(name: String, query: String) - case succeedPreparedStatementCreation(PrepareStatementContext, with: RowDescription?) - case failPreparedStatementCreation(PrepareStatementContext, with: PSQLError) - - case read - case wait - } - - var state: State - - init(createContext: PrepareStatementContext) { - self.state = .initialized(createContext) - } - - #if DEBUG - /// for testing purposes only - init(_ state: State) { - self.state = state - } - #endif - - mutating func start() -> Action { - guard case .initialized(let createContext) = self.state else { - preconditionFailure("Start must only be called after the query has been initialized") - } - - self.state = .parseDescribeSent(createContext) - - return .sendParseDescribeSync(name: createContext.name, query: createContext.query) - } - - mutating func parseCompletedReceived() -> Action { - guard case .parseDescribeSent(let createContext) = self.state else { - return self.setAndFireError(.unexpectedBackendMessage(.parseComplete)) - } - - self.state = .parseCompleteReceived(createContext) - return .wait - } - - mutating func parameterDescriptionReceived(_ parameterDescription: PostgresBackendMessage.ParameterDescription) -> Action { - guard case .parseCompleteReceived(let createContext) = self.state else { - return self.setAndFireError(.unexpectedBackendMessage(.parameterDescription(parameterDescription))) - } - - self.state = .parameterDescriptionReceived(createContext) - return .wait - } - - mutating func noDataReceived() -> Action { - guard case .parameterDescriptionReceived(let queryContext) = self.state else { - return self.setAndFireError(.unexpectedBackendMessage(.noData)) - } - - self.state = .noDataMessageReceived - return .succeedPreparedStatementCreation(queryContext, with: nil) - } - - mutating func rowDescriptionReceived(_ rowDescription: RowDescription) -> Action { - guard case .parameterDescriptionReceived(let queryContext) = self.state else { - return self.setAndFireError(.unexpectedBackendMessage(.rowDescription(rowDescription))) - } - - self.state = .rowDescriptionReceived - return .succeedPreparedStatementCreation(queryContext, with: rowDescription) - } - - mutating func errorReceived(_ errorMessage: PostgresBackendMessage.ErrorResponse) -> Action { - let error = PSQLError.server(errorMessage) - switch self.state { - case .initialized: - return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) - - case .parseDescribeSent, - .parseCompleteReceived, - .parameterDescriptionReceived: - return self.setAndFireError(error) - - case .rowDescriptionReceived, - .noDataMessageReceived, - .error: - preconditionFailure(""" - This state must not be reached. If the prepared statement `.isComplete`, the - ConnectionStateMachine must not send any further events to the substate machine. - """) - } - } - - mutating func errorHappened(_ error: PSQLError) -> Action { - return self.setAndFireError(error) - } - - private mutating func setAndFireError(_ error: PSQLError) -> Action { - switch self.state { - case .initialized(let context), - .parseDescribeSent(let context), - .parseCompleteReceived(let context), - .parameterDescriptionReceived(let context): - self.state = .error(error) - return .failPreparedStatementCreation(context, with: error) - case .rowDescriptionReceived, - .noDataMessageReceived, - .error: - preconditionFailure(""" - This state must not be reached. If the prepared statement `.isComplete`, the - ConnectionStateMachine must not send any further events to the substate machine. - """) - } - } - - // MARK: Channel actions - - mutating func readEventCaught() -> Action { - return .read - } - - var isComplete: Bool { - switch self.state { - case .rowDescriptionReceived, - .noDataMessageReceived, - .error: - return true - case .initialized, - .parseDescribeSent, - .parseCompleteReceived, - .parameterDescriptionReceived: - return false - } - } - -} diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index 4c842275..b008d185 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -1,12 +1,23 @@ import NIOCore import Logging +struct QueryResult { + enum Value: Equatable { + case noRows(String) + case rowDescription([RowDescription.Column]) + } + + var value: Value + + var logger: Logger +} + // Thread safety is guaranteed in the RowStream through dispatching onto the NIO EventLoop. final class PSQLRowStream: @unchecked Sendable { private typealias AsyncSequenceSource = NIOThrowingAsyncSequenceProducer.Source - enum RowSource { - case stream(PSQLRowsDataSource) + enum Source { + case stream([RowDescription.Column], PSQLRowsDataSource) case noRows(Result) } @@ -31,27 +42,28 @@ final class PSQLRowStream: @unchecked Sendable { private let lookupTable: [String: Int] private var downstreamState: DownstreamState - init(rowDescription: [RowDescription.Column], - queryContext: ExtendedQueryContext, - eventLoop: EventLoop, - rowSource: RowSource) - { + init( + source: Source, + eventLoop: EventLoop, + logger: Logger + ) { let bufferState: BufferState - switch rowSource { - case .stream(let dataSource): + switch source { + case .stream(let rowDescription, let dataSource): + self.rowDescription = rowDescription bufferState = .streaming(buffer: .init(), dataSource: dataSource) case .noRows(.success(let commandTag)): + self.rowDescription = [] bufferState = .finished(buffer: .init(), commandTag: commandTag) case .noRows(.failure(let error)): + self.rowDescription = [] bufferState = .failure(error) } self.downstreamState = .waitingForConsumer(bufferState) self.eventLoop = eventLoop - self.logger = queryContext.logger - - self.rowDescription = rowDescription + self.logger = logger var lookup = [String: Int]() lookup.reserveCapacity(rowDescription.count) diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 26312c0c..f5de6561 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -3,7 +3,6 @@ import NIOCore enum HandlerTask { case extendedQuery(ExtendedQueryContext) - case preparedStatement(PrepareStatementContext) case closeCommand(CloseCommandContext) case startListening(NotificationListener) case cancelListening(String, Int) @@ -11,16 +10,19 @@ enum HandlerTask { enum PSQLTask { case extendedQuery(ExtendedQueryContext) - case preparedStatement(PrepareStatementContext) case closeCommand(CloseCommandContext) func failWithError(_ error: PSQLError) { switch self { case .extendedQuery(let extendedQueryContext): - extendedQueryContext.promise.fail(error) - - case .preparedStatement(let createPreparedStatementContext): - createPreparedStatementContext.promise.fail(error) + switch extendedQueryContext.query { + case .unnamed(_, let eventLoopPromise): + eventLoopPromise.fail(error) + case .executeStatement(_, let eventLoopPromise): + eventLoopPromise.fail(error) + case .prepareStatement(_, _, let eventLoopPromise): + eventLoopPromise.fail(error) + } case .closeCommand(let closeCommandContext): closeCommandContext.promise.fail(error) @@ -30,49 +32,40 @@ enum PSQLTask { final class ExtendedQueryContext { enum Query { - case unnamed(PostgresQuery) - case preparedStatement(PSQLExecuteStatement) + case unnamed(PostgresQuery, EventLoopPromise) + case executeStatement(PSQLExecuteStatement, EventLoopPromise) + case prepareStatement(name: String, query: String, EventLoopPromise) } let query: Query let logger: Logger - - let promise: EventLoopPromise - init(query: PostgresQuery, - logger: Logger, - promise: EventLoopPromise) - { - self.query = .unnamed(query) + init( + query: PostgresQuery, + logger: Logger, + promise: EventLoopPromise + ) { + self.query = .unnamed(query, promise) self.logger = logger - self.promise = promise } - init(executeStatement: PSQLExecuteStatement, - logger: Logger, - promise: EventLoopPromise) - { - self.query = .preparedStatement(executeStatement) + init( + executeStatement: PSQLExecuteStatement, + logger: Logger, + promise: EventLoopPromise + ) { + self.query = .executeStatement(executeStatement, promise) self.logger = logger - self.promise = promise } -} -final class PrepareStatementContext { - let name: String - let query: String - let logger: Logger - let promise: EventLoopPromise - - init(name: String, - query: String, - logger: Logger, - promise: EventLoopPromise) - { - self.name = name - self.query = query + init( + name: String, + query: String, + logger: Logger, + promise: EventLoopPromise + ) { + self.query = .prepareStatement(name: name, query: query, promise) self.logger = logger - self.promise = promise } } @@ -83,8 +76,8 @@ final class CloseCommandContext { init(target: CloseTarget, logger: Logger, - promise: EventLoopPromise) - { + promise: EventLoopPromise + ) { self.target = target self.logger = logger self.promise = promise diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 32c35927..abfa5aeb 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -206,8 +206,6 @@ final class PostgresChannelHandler: ChannelDuplexHandler { psqlTask = .closeCommand(command) case .extendedQuery(let query): psqlTask = .extendedQuery(query) - case .preparedStatement(let statement): - psqlTask = .preparedStatement(statement) case .startListening(let listener): switch self.listenState.startListening(listener) { @@ -326,12 +324,10 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.sendBindExecuteAndSyncMessage(executeStatement: executeStatement, context: context) case .sendParseDescribeBindExecuteSync(let query): self.sendParseDescribeBindExecuteAndSyncMessage(query: query, context: context) - case .succeedQuery(let queryContext, columns: let columns): - self.succeedQueryWithRowStream(queryContext, columns: columns, context: context) - case .succeedQueryNoRowsComming(let queryContext, let commandTag): - self.succeedQueryWithoutRowStream(queryContext, commandTag: commandTag, context: context) - case .failQuery(let queryContext, with: let error, let cleanupContext): - queryContext.promise.fail(error) + case .succeedQuery(let promise, with: let result): + self.succeedQuery(promise, result: result, context: context) + case .failQuery(let promise, with: let error, let cleanupContext): + promise.fail(error) if let cleanupContext = cleanupContext { self.closeConnectionAndCleanup(cleanupContext, context: context) } @@ -383,10 +379,10 @@ final class PostgresChannelHandler: ChannelDuplexHandler { context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) } context.close(mode: .all, promise: promise) - case .succeedPreparedStatementCreation(let preparedContext, with: let rowDescription): - preparedContext.promise.succeed(rowDescription) - case .failPreparedStatementCreation(let preparedContext, with: let error, let cleanupContext): - preparedContext.promise.fail(error) + case .succeedPreparedStatementCreation(let promise, with: let rowDescription): + promise.succeed(rowDescription) + case .failPreparedStatementCreation(let promise, with: let error, let cleanupContext): + promise.fail(error) if let cleanupContext = cleanupContext { self.closeConnectionAndCleanup(cleanupContext, context: context) } @@ -510,33 +506,30 @@ final class PostgresChannelHandler: ChannelDuplexHandler { context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) } - private func succeedQueryWithRowStream( - _ queryContext: ExtendedQueryContext, - columns: [RowDescription.Column], + private func succeedQuery( + _ promise: EventLoopPromise, + result: QueryResult, context: ChannelHandlerContext ) { - let rows = PSQLRowStream( - rowDescription: columns, - queryContext: queryContext, - eventLoop: context.channel.eventLoop, - rowSource: .stream(self)) - - self.rowStream = rows - queryContext.promise.succeed(rows) - } - - private func succeedQueryWithoutRowStream( - _ queryContext: ExtendedQueryContext, - commandTag: String, - context: ChannelHandlerContext - ) { - let rows = PSQLRowStream( - rowDescription: [], - queryContext: queryContext, - eventLoop: context.channel.eventLoop, - rowSource: .noRows(.success(commandTag)) - ) - queryContext.promise.succeed(rows) + let rows: PSQLRowStream + switch result.value { + case .rowDescription(let columns): + rows = PSQLRowStream( + source: .stream(columns, self), + eventLoop: context.channel.eventLoop, + logger: result.logger + ) + self.rowStream = rows + + case .noRows(let commandTag): + rows = PSQLRowStream( + source: .noRows(.success(commandTag)), + eventLoop: context.channel.eventLoop, + logger: result.logger + ) + } + + promise.succeed(rows) } private func closeConnectionAndCleanup( diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index d5d4ecb1..5fd3bc20 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -180,9 +180,9 @@ class ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.errorReceived(.init(fields: fields)), .closeConnectionAndCleanup(.init(action: .close, tasks: [.extendedQuery(extendedQueryContext)], error: .server(.init(fields: fields)), closePromise: nil))) - XCTAssertNil(extendedQueryContext.promise.futureResult._value) - + XCTAssertNil(queryPromise.futureResult._value) + // make sure we don't crash - extendedQueryContext.promise.fail(PSQLError.server(.init(fields: fields))) + queryPromise.fail(PSQLError.server(.init(fields: fields))) } } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index eac46e5f..40e32468 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -20,7 +20,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) XCTAssertEqual(state.noDataReceived(), .wait) XCTAssertEqual(state.bindCompleteReceived(), .wait) - XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQueryNoRowsComming(queryContext, commandTag: "DELETE 1")) + XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows("DELETE 1"), logger: logger))) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) } @@ -49,7 +49,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { } XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) - XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(queryContext, columns: expected)) + XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(promise, with: .init(value: .rowDescription(expected), logger: logger))) let row1: DataRow = [ByteBuffer(string: "test1")] XCTAssertEqual(state.dataRowReceived(row1), .wait) XCTAssertEqual(state.channelReadComplete(), .forwardRows([row1])) @@ -93,7 +93,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let psqlError = PSQLError.unexpectedBackendMessage(.authentication(.ok)) XCTAssertEqual(state.authenticationMessageReceived(.ok), - .failQuery(queryContext, with: psqlError, cleanupContext: .init(action: .close, tasks: [], error: psqlError, closePromise: nil))) + .failQuery(promise, with: psqlError, cleanupContext: .init(action: .close, tasks: [], error: psqlError, closePromise: nil))) } func testExtendedQueryIsCancelledImmediatly() { @@ -121,7 +121,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { } XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) - XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(queryContext, columns: expected)) + XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(promise, with: .init(value: .rowDescription(expected), logger: logger))) XCTAssertEqual(state.cancelQueryStream(), .forwardStreamError(.queryCancelled, read: false, cleanupContext: nil)) XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test1")]), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) @@ -165,7 +165,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { } XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) - XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(queryContext, columns: expected)) + XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(promise, with: .init(value: .rowDescription(expected), logger: logger))) let row1: DataRow = [ByteBuffer(string: "test1")] XCTAssertEqual(state.dataRowReceived(row1), .wait) XCTAssertEqual(state.channelReadComplete(), .forwardRows([row1])) @@ -207,7 +207,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { } XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) - XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(queryContext, columns: expected)) + XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(promise, with: .init(value: .rowDescription(expected), logger: logger))) let dataRows1: [DataRow] = [ [ByteBuffer(string: "test1")], [ByteBuffer(string: "test2")], @@ -251,7 +251,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"]) XCTAssertEqual( - state.errorReceived(serverError), .failQuery(queryContext, with: .server(serverError), cleanupContext: .none) + state.errorReceived(serverError), .failQuery(promise, with: .server(serverError), cleanupContext: .none) ) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) @@ -269,7 +269,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) - XCTAssertEqual(state.cancelQueryStream(), .failQuery(queryContext, with: .queryCancelled, cleanupContext: .none)) + XCTAssertEqual(state.cancelQueryStream(), .failQuery(promise, with: .queryCancelled, cleanupContext: .none)) let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"]) XCTAssertEqual(state.errorReceived(serverError), .wait) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift index 6cff280e..6a08afeb 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift @@ -3,7 +3,6 @@ import NIOEmbedded @testable import PostgresNIO class PrepareStatementStateMachineTests: XCTestCase { - func testCreatePreparedStatementReturningRowDescription() { var state = ConnectionStateMachine.readyForQuery() @@ -12,10 +11,11 @@ class PrepareStatementStateMachineTests: XCTestCase { let name = "haha" let query = #"SELECT id FROM users WHERE id = $1 "# - let prepareStatementContext = PrepareStatementContext( - name: name, query: query, logger: .psqlTest, promise: promise) - - XCTAssertEqual(state.enqueue(task: .preparedStatement(prepareStatementContext)), + let prepareStatementContext = ExtendedQueryContext( + name: name, query: query, logger: .psqlTest, promise: promise + ) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), .sendParseDescribeSync(name: name, query: query)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) @@ -25,7 +25,7 @@ class PrepareStatementStateMachineTests: XCTestCase { ] XCTAssertEqual(state.rowDescriptionReceived(.init(columns: columns)), - .succeedPreparedStatementCreation(prepareStatementContext, with: .init(columns: columns))) + .succeedPreparedStatementCreation(promise, with: .init(columns: columns))) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) } @@ -37,25 +37,42 @@ class PrepareStatementStateMachineTests: XCTestCase { let name = "haha" let query = #"DELETE FROM users WHERE id = $1 "# - let prepareStatementContext = PrepareStatementContext( - name: name, query: query, logger: .psqlTest, promise: promise) - - XCTAssertEqual(state.enqueue(task: .preparedStatement(prepareStatementContext)), + let prepareStatementContext = ExtendedQueryContext( + name: name, query: query, logger: .psqlTest, promise: promise + ) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), .sendParseDescribeSync(name: name, query: query)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) XCTAssertEqual(state.noDataReceived(), - .succeedPreparedStatementCreation(prepareStatementContext, with: nil)) + .succeedPreparedStatementCreation(promise, with: nil)) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) } func testErrorReceivedAfter() { - let connectionContext = ConnectionStateMachine.createConnectionContext() - var state = ConnectionStateMachine(.prepareStatement(.init(.noDataMessageReceived), connectionContext)) - + var state = ConnectionStateMachine.readyForQuery() + + let promise = EmbeddedEventLoop().makePromise(of: RowDescription?.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + + let name = "haha" + let query = #"DELETE FROM users WHERE id = $1 "# + let prepareStatementContext = ExtendedQueryContext( + name: name, query: query, logger: .psqlTest, promise: promise + ) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), + .sendParseDescribeSync(name: name, query: query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + + XCTAssertEqual(state.noDataReceived(), + .succeedPreparedStatementCreation(promise, with: nil)) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + XCTAssertEqual(state.authenticationMessageReceived(.ok), .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unexpectedBackendMessage(.authentication(.ok)), closePromise: nil))) } - } diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift index 72420798..febeee37 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -25,13 +25,10 @@ extension ConnectionStateMachine.ConnectionAction: Equatable { return lquery == rquery case (.fireEventReadyForQuery, .fireEventReadyForQuery): return true - - case (.succeedQueryNoRowsComming(let lhsContext, let lhsCommandTag), .succeedQueryNoRowsComming(let rhsContext, let rhsCommandTag)): - return lhsContext === rhsContext && lhsCommandTag == rhsCommandTag - case (.succeedQuery(let lhsContext, let lhsRowDescription), .succeedQuery(let rhsContext, let rhsRowDescription)): - return lhsContext === rhsContext && lhsRowDescription == rhsRowDescription - case (.failQuery(let lhsContext, let lhsError, let lhsCleanupContext), .failQuery(let rhsContext, let rhsError, let rhsCleanupContext)): - return lhsContext === rhsContext && lhsError == rhsError && lhsCleanupContext == rhsCleanupContext + case (.succeedQuery(let lhsPromise, let lhsResult), .succeedQuery(let rhsPromise, let rhsResult)): + return lhsPromise.futureResult === rhsPromise.futureResult && lhsResult.value == rhsResult.value + case (.failQuery(let lhsPromise, let lhsError, let lhsCleanupContext), .failQuery(let rhsPromise, let rhsError, let rhsCleanupContext)): + return lhsPromise.futureResult === rhsPromise.futureResult && lhsError == rhsError && lhsCleanupContext == rhsCleanupContext case (.forwardRows(let lhsRows), .forwardRows(let rhsRows)): return lhsRows == rhsRows case (.forwardStreamComplete(let lhsBuffer, let lhsCommandTag), .forwardStreamComplete(let rhsBuffer, let rhsCommandTag)): @@ -40,8 +37,8 @@ extension ConnectionStateMachine.ConnectionAction: Equatable { return lhsError == rhsError && lhsRead == rhsRead && lhsCleanupContext == rhsCleanupContext case (.sendParseDescribeSync(let lhsName, let lhsQuery), .sendParseDescribeSync(let rhsName, let rhsQuery)): return lhsName == rhsName && lhsQuery == rhsQuery - case (.succeedPreparedStatementCreation(let lhsContext, let lhsRowDescription), .succeedPreparedStatementCreation(let rhsContext, let rhsRowDescription)): - return lhsContext === rhsContext && lhsRowDescription == rhsRowDescription + case (.succeedPreparedStatementCreation(let lhsPromise, let lhsRowDescription), .succeedPreparedStatementCreation(let rhsPromise, let rhsRowDescription)): + return lhsPromise.futureResult === rhsPromise.futureResult && lhsRowDescription == rhsRowDescription case (.fireChannelInactive, .fireChannelInactive): return true default: @@ -110,8 +107,6 @@ extension PSQLTask: Equatable { switch (lhs, rhs) { case (.extendedQuery(let lhs), .extendedQuery(let rhs)): return lhs === rhs - case (.preparedStatement(let lhs), .preparedStatement(let rhs)): - return lhs === rhs case (.closeCommand(let lhs), .closeCommand(let rhs)): return lhs === rhs default: diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift index f27ff060..1af35fac 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -5,44 +5,27 @@ import XCTest import NIOCore import NIOEmbedded -class PSQLRowStreamTests: XCTestCase { +final class PSQLRowStreamTests: XCTestCase { + let logger = Logger(label: "PSQLRowStreamTests") + let eventLoop = EmbeddedEventLoop() + func testEmptyStream() { - let logger = Logger(label: "test") - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - - let queryContext = ExtendedQueryContext( - query: "INSERT INTO foo bar;", logger: logger, promise: promise - ) - let stream = PSQLRowStream( - rowDescription: [], - queryContext: queryContext, - eventLoop: eventLoop, - rowSource: .noRows(.success("INSERT 0 1")) + source: .noRows(.success("INSERT 0 1")), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) XCTAssertEqual(try stream.all().wait(), []) XCTAssertEqual(stream.commandTag, "INSERT 0 1") } func testFailedStream() { - let logger = Logger(label: "test") - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - - let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", logger: logger, promise: promise - ) - let stream = PSQLRowStream( - rowDescription: [], - queryContext: queryContext, - eventLoop: eventLoop, - rowSource: .noRows(.failure(PSQLError.connectionClosed)) + source: .noRows(.failure(PSQLError.connectionClosed)), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) XCTAssertThrowsError(try stream.all().wait()) { XCTAssertEqual($0 as? PSQLError, .connectionClosed) @@ -50,24 +33,15 @@ class PSQLRowStreamTests: XCTestCase { } func testGetArrayAfterStreamHasFinished() { - let logger = Logger(label: "test") - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - - let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", logger: logger, promise: promise - ) - let dataSource = CountingDataSource() let stream = PSQLRowStream( - rowDescription: [ - self.makeColumnDescription(name: "foo", dataType: .text, format: .binary) - ], - queryContext: queryContext, - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [self.makeColumnDescription(name: "foo", dataType: .text, format: .binary)], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) XCTAssertEqual(dataSource.hitDemand, 0) XCTAssertEqual(dataSource.hitCancel, 0) @@ -89,22 +63,15 @@ class PSQLRowStreamTests: XCTestCase { } func testGetArrayBeforeStreamHasFinished() { - let logger = Logger(label: "test") - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - - let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", logger: logger, promise: promise) let dataSource = CountingDataSource() let stream = PSQLRowStream( - rowDescription: [ - self.makeColumnDescription(name: "foo", dataType: .text, format: .binary) - ], - queryContext: queryContext, - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [self.makeColumnDescription(name: "foo", dataType: .text, format: .binary)], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) XCTAssertEqual(dataSource.hitDemand, 0) XCTAssertEqual(dataSource.hitCancel, 0) @@ -139,24 +106,15 @@ class PSQLRowStreamTests: XCTestCase { } func testOnRowAfterStreamHasFinished() { - let logger = Logger(label: "test") - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - - let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", logger: logger, promise: promise - ) - let dataSource = CountingDataSource() let stream = PSQLRowStream( - rowDescription: [ - self.makeColumnDescription(name: "foo", dataType: .text, format: .binary) - ], - queryContext: queryContext, - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [self.makeColumnDescription(name: "foo", dataType: .text, format: .binary)], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) XCTAssertEqual(dataSource.hitDemand, 0) XCTAssertEqual(dataSource.hitCancel, 0) @@ -183,24 +141,15 @@ class PSQLRowStreamTests: XCTestCase { } func testOnRowThrowsErrorOnInitialBatch() { - let logger = Logger(label: "test") - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - - let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", logger: logger, promise: promise - ) - let dataSource = CountingDataSource() let stream = PSQLRowStream( - rowDescription: [ - self.makeColumnDescription(name: "foo", dataType: .text, format: .binary) - ], - queryContext: queryContext, - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [self.makeColumnDescription(name: "foo", dataType: .text, format: .binary)], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) XCTAssertEqual(dataSource.hitDemand, 0) XCTAssertEqual(dataSource.hitCancel, 0) @@ -232,24 +181,15 @@ class PSQLRowStreamTests: XCTestCase { func testOnRowBeforeStreamHasFinished() { - let logger = Logger(label: "test") - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - - let queryContext = ExtendedQueryContext( - query: "SELECT * FROM test;", logger: logger, promise: promise - ) - let dataSource = CountingDataSource() let stream = PSQLRowStream( - rowDescription: [ - self.makeColumnDescription(name: "foo", dataType: .text, format: .binary) - ], - queryContext: queryContext, - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [self.makeColumnDescription(name: "foo", dataType: .text, format: .binary)], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) XCTAssertEqual(dataSource.hitDemand, 0) XCTAssertEqual(dataSource.hitCancel, 0) diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index e1fdad11..fc589c0b 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -7,21 +7,21 @@ import NIOCore import Logging final class PostgresRowSequenceTests: XCTestCase { + let logger = Logger(label: "PSQLRowStreamTests") + let eventLoop = EmbeddedEventLoop() func testBackpressureWorks() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) let rowSequence = stream.asyncSequence() XCTAssertEqual(dataSource.requestCount, 0) @@ -38,20 +38,19 @@ final class PostgresRowSequenceTests: XCTestCase { XCTAssertNil(empty) } + func testCancellationWorksWhileIterating() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) let rowSequence = stream.asyncSequence() XCTAssertEqual(dataSource.requestCount, 0) @@ -72,19 +71,17 @@ final class PostgresRowSequenceTests: XCTestCase { } func testCancellationWorksBeforeIterating() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) let rowSequence = stream.asyncSequence() XCTAssertEqual(dataSource.requestCount, 0) @@ -99,19 +96,17 @@ final class PostgresRowSequenceTests: XCTestCase { } func testDroppingTheSequenceCancelsTheSource() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) var rowSequence: PostgresRowSequence? = stream.asyncSequence() rowSequence = nil @@ -121,19 +116,17 @@ final class PostgresRowSequenceTests: XCTestCase { } func testStreamBasedOnCompletedQuery() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) let rowSequence = stream.asyncSequence() let dataRows: [DataRow] = (0..<128).map { [ByteBuffer(integer: Int64($0))] } @@ -150,19 +143,17 @@ final class PostgresRowSequenceTests: XCTestCase { } func testStreamIfInitializedWithAllData() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) let dataRows: [DataRow] = (0..<128).map { [ByteBuffer(integer: Int64($0))] } stream.receive(dataRows) @@ -180,19 +171,17 @@ final class PostgresRowSequenceTests: XCTestCase { } func testStreamIfInitializedWithError() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) stream.receive(completion: .failure(PSQLError.connectionClosed)) @@ -210,19 +199,17 @@ final class PostgresRowSequenceTests: XCTestCase { } func testSucceedingRowContinuationsWorks() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) let rowSequence = stream.asyncSequence() var rowIterator = rowSequence.makeAsyncIterator() @@ -244,19 +231,17 @@ final class PostgresRowSequenceTests: XCTestCase { } func testFailingRowContinuationsWorks() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) let rowSequence = stream.asyncSequence() var rowIterator = rowSequence.makeAsyncIterator() @@ -282,19 +267,17 @@ final class PostgresRowSequenceTests: XCTestCase { } func testAdaptiveRowBufferShrinksAndGrows() async throws { - let eventLoop = EmbeddedEventLoop() - let promise = eventLoop.makePromise(of: PSQLRowStream.self) - let logger = Logger(label: "test") let dataSource = MockRowDataSource() let stream = PSQLRowStream( - rowDescription: [ - .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) - ], - queryContext: .init(query: "SELECT * FROM foo", logger: logger, promise: promise), - eventLoop: eventLoop, - rowSource: .stream(dataSource) + source: .stream( + [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ], + dataSource + ), + eventLoop: self.eventLoop, + logger: self.logger ) - promise.succeed(stream) let initialDataRows: [DataRow] = (0.. Date: Wed, 9 Aug 2023 23:11:53 +0200 Subject: [PATCH 173/292] PostgresNotificationSequence is not Sendable in 5.6 (#392) `AsyncThrowingStream` is not `Sendable` in Swift 5.6. Because of this `PostgresNotificationSequence` can not be `Sendable` in 5.6. --- Sources/PostgresNIO/New/PostgresNotificationSequence.swift | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/Sources/PostgresNIO/New/PostgresNotificationSequence.swift b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift index 735c01b0..55fb0670 100644 --- a/Sources/PostgresNIO/New/PostgresNotificationSequence.swift +++ b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift @@ -3,7 +3,7 @@ public struct PostgresNotification: Sendable { public let payload: String } -public struct PostgresNotificationSequence: AsyncSequence, Sendable { +public struct PostgresNotificationSequence: AsyncSequence { public typealias Element = PostgresNotification let base: AsyncThrowingStream @@ -20,3 +20,8 @@ public struct PostgresNotificationSequence: AsyncSequence, Sendable { } } } + +#if swift(>=5.7) +// AsyncThrowingStream is marked as Sendable in Swift 5.6 +extension PostgresNotificationSequence: Sendable {} +#endif From a5758b0c1bcbf3f0a27335d60813509a93027dc5 Mon Sep 17 00:00:00 2001 From: Thomas Krajacic Date: Wed, 9 Aug 2023 23:17:01 +0200 Subject: [PATCH 174/292] Use EventLoop provided by SwiftNIO's MultiThreadedEventLoopGroup.singleton (#389) Co-authored-by: Fabian Fett --- Package.swift | 4 ++-- README.md | 20 +----------------- .../Connection/PostgresConnection.swift | 21 +++++++++++++++++-- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/Package.swift b/Package.swift index c1cb4bda..a45925ed 100644 --- a/Package.swift +++ b/Package.swift @@ -14,8 +14,8 @@ let package = Package( ], dependencies: [ .package(url: "https://github.com/apple/swift-atomics.git", from: "1.1.0"), - .package(url: "https://github.com/apple/swift-nio.git", from: "2.52.0"), - .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.16.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.58.0"), + .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.18.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.23.1"), .package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "3.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.0.0"), diff --git a/README.md b/README.md index 51e0b8c5..441a41e3 100644 --- a/README.md +++ b/README.md @@ -67,19 +67,7 @@ let config = PostgresConnection.Configuration( ) ``` -A connection must be created on a SwiftNIO `EventLoop`. In most server use cases, an -`EventLoopGroup` is created at app startup and closed during app shutdown. - -```swift -import NIOPosix - -let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - -// Much later -try await eventLoopGroup.shutdownGracefully() -``` - -A [`Logger`] is also required. +To create a connection we need a [`Logger`], that is used to log connection background events. ```swift import Logging @@ -91,10 +79,8 @@ Now we can put it together: ```swift import PostgresNIO -import NIOPosix import Logging -let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) let logger = Logger(label: "postgres-logger") let config = PostgresConnection.Configuration( @@ -107,7 +93,6 @@ let config = PostgresConnection.Configuration( ) let connection = try await PostgresConnection.connect( - on: eventLoopGroup.next(), configuration: config, id: 1, logger: logger @@ -115,9 +100,6 @@ let connection = try await PostgresConnection.connect( // Close your connection once done try await connection.close() - -// Shutdown the EventLoopGroup, once all connections are closed. -try await eventLoopGroup.shutdownGracefully() ``` #### Querying diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 6f849bdd..f8a9709e 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -360,13 +360,13 @@ extension PostgresConnection { /// Creates a new connection to a Postgres server. /// /// - Parameters: - /// - eventLoop: The `EventLoop` the request shall be created on + /// - eventLoop: The `EventLoop` the connection shall be created on. /// - configuration: A ``Configuration`` that shall be used for the connection /// - connectionID: An `Int` id, used for metadata logging /// - logger: A logger to log background events into /// - Returns: An established ``PostgresConnection`` asynchronously that can be used to run queries. public static func connect( - on eventLoop: EventLoop, + on eventLoop: EventLoop = PostgresConnection.defaultEventLoopGroup.any(), configuration: PostgresConnection.Configuration, id connectionID: ID, logger: Logger @@ -661,3 +661,20 @@ extension EventLoopFuture { } } } + +extension PostgresConnection { + /// Returns the default `EventLoopGroup` singleton, automatically selecting the best for the platform. + /// + /// This will select the concrete `EventLoopGroup` depending which platform this is running on. + public static var defaultEventLoopGroup: EventLoopGroup { +#if canImport(Network) + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) { + return NIOTSEventLoopGroup.singleton + } else { + return MultiThreadedEventLoopGroup.singleton + } +#else + return MultiThreadedEventLoopGroup.singleton +#endif + } +} From 52d5636edd2da896d1669dfd7fd4f83de94686c4 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 10 Aug 2023 08:03:36 +0200 Subject: [PATCH 175/292] `close()` closes immediately; Add new `closeGracefully()` (#383) Fixes #370. This patch changes the behavior of `PostgresConnection.close()`. Currently `close()` terminates the connection only after all queued queries have been successfully processed by the server. This however leads to an unwanted dependency on the Postgres server to close a connection. If a server stops responding, the client is currently unable to close its connection. Because of this, this patch changes the behavior of `close()`. `close()` now terminates a connection immediately and fails all running or queued queries. To allow users to continue to use the existing behavior we introduce a `closeGracefully()` that now has the same behavior as close had previously. Since we never documented the old close behavior and we consider it dangerous in certain situations we are fine with changing the behavior without tagging a major version. --- .../Connection/PostgresConnection.swift | 11 ++ .../ConnectionStateMachine.swift | 164 ++++++++++-------- .../ListenStateMachine.swift | 11 +- Sources/PostgresNIO/New/PSQLError.swift | 54 ++++-- .../PostgresNIO/New/PSQLEventsHandler.swift | 2 + .../New/PostgresChannelHandler.swift | 7 +- Sources/PostgresNIO/Postgres+PSQLCompat.swift | 6 +- .../ConnectionStateMachineTests.swift | 6 +- .../New/PSQLRowStreamTests.swift | 4 +- .../New/PostgresChannelHandlerTests.swift | 7 +- .../New/PostgresConnectionTests.swift | 92 ++++++++++ .../New/PostgresRowSequenceTests.swift | 8 +- 12 files changed, 263 insertions(+), 109 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index f8a9709e..7ac8ec57 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -384,6 +384,17 @@ extension PostgresConnection { try await self.close().get() } + /// Closes the connection to the server, _after all queries_ that have been created on this connection have been run. + public func closeGracefully() async throws { + try await withTaskCancellationHandler { () async throws -> () in + let promise = self.eventLoop.makePromise(of: Void.self) + self.channel.triggerUserOutboundEvent(PSQLOutgoingEvent.gracefulShutdown, promise: promise) + return try await promise.futureResult.get() + } onCancel: { + _ = self.close() + } + } + /// Run a query on the Postgres server the connection is connected to. /// /// - Parameters: diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 0f3e96c9..bbfa0faa 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -32,11 +32,10 @@ struct ConnectionStateMachine { case readyForQuery(ConnectionContext) case extendedQuery(ExtendedQueryStateMachine, ConnectionContext) case closeCommand(CloseStateMachine, ConnectionContext) - - case error(PSQLError) - case closing - case closed - + + case closing(PSQLError?) + case closed(clientInitiated: Bool, error: PSQLError?) + case modifying } @@ -158,7 +157,6 @@ struct ConnectionStateMachine { .readyForQuery, .extendedQuery, .closeCommand, - .error, .closing, .closed, .modifying: @@ -170,9 +168,9 @@ struct ConnectionStateMachine { self.startAuthentication(authContext) } - mutating func close(_ promise: EventLoopPromise?) -> ConnectionAction { + mutating func gracefulClose(_ promise: EventLoopPromise?) -> ConnectionAction { switch self.state { - case .closing, .closed, .error: + case .closing, .closed: // we are already closed, but sometimes an upstream handler might want to close the // connection, though it has already been closed by the remote. Typical race condition. return .closeConnection(promise) @@ -180,7 +178,7 @@ struct ConnectionStateMachine { precondition(self.taskQueue.isEmpty, """ The state should only be .readyForQuery if there are no more tasks in the queue """) - self.state = .closing + self.state = .closing(nil) return .closeConnection(promise) default: switch self.quiescingState { @@ -194,7 +192,11 @@ struct ConnectionStateMachine { return .wait } } - + + mutating func close(promise: EventLoopPromise?) -> ConnectionAction { + return self.closeConnectionAndCleanup(.clientClosedConnection(underlying: nil), closePromise: promise) + } + mutating func closed() -> ConnectionAction { switch self.state { case .initialized: @@ -214,8 +216,8 @@ struct ConnectionStateMachine { .closeCommand: return self.errorHappened(.uncleanShutdown) - case .error, .closing: - self.state = .closed + case .closing(let error): + self.state = .closed(clientInitiated: true, error: error) self.quiescingState = .notQuiescing return .fireChannelInactive @@ -242,7 +244,6 @@ struct ConnectionStateMachine { .readyForQuery, .extendedQuery, .closeCommand, - .error, .closing, .closed: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.sslSupported)) @@ -270,7 +271,6 @@ struct ConnectionStateMachine { .readyForQuery, .extendedQuery, .closeCommand, - .error, .closing, .closed: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.sslSupported)) @@ -291,7 +291,6 @@ struct ConnectionStateMachine { .readyForQuery, .extendedQuery, .closeCommand, - .error, .closing, .closed: preconditionFailure("Can only add a ssl handler after negotiation: \(self.state)") @@ -316,7 +315,6 @@ struct ConnectionStateMachine { .readyForQuery, .extendedQuery, .closeCommand, - .error, .closing, .closed: preconditionFailure("Can only establish a ssl connection after adding a ssl handler: \(self.state)") @@ -363,8 +361,7 @@ struct ConnectionStateMachine { .waitingToStartAuthentication, .authenticating, .closing: - self.state = .error(.unexpectedBackendMessage(.parameterStatus(status))) - return .wait + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.parameterStatus(status))) case .authenticated(let keyData, var parameters): return self.avoidingStateMachineCoW { machine in parameters[status.parameter] = status.value @@ -389,8 +386,6 @@ struct ConnectionStateMachine { machine.state = .closeCommand(closeState, connectionContext) return .wait } - case .error(_): - return .wait case .initialized, .closed: preconditionFailure("We shouldn't receive messages if we are not connected") @@ -406,8 +401,7 @@ struct ConnectionStateMachine { .sslHandlerAdded, .waitingToStartAuthentication, .authenticated, - .readyForQuery, - .error: + .readyForQuery: return self.closeConnectionAndCleanup(.server(errorMessage)) case .authenticating(var authState): if authState.isComplete { @@ -477,8 +471,6 @@ struct ConnectionStateMachine { let action = closeState.errorHappened(error) return self.modify(with: action) } - case .error: - return .wait case .closing: // If the state machine is in state `.closing`, the connection shutdown was initiated // by the client. This means a `TERMINATE` message has already been sent and the @@ -553,40 +545,54 @@ struct ConnectionStateMachine { } mutating func enqueue(task: PSQLTask) -> ConnectionAction { + let psqlErrror: PSQLError + // check if we are quiescing. if so fail task immidiatly - if case .quiescing = self.quiescingState { - switch task { - case .extendedQuery(let queryContext): - switch queryContext.query { - case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): - return .failQuery(eventLoopPromise, with: .connectionQuiescing, cleanupContext: nil) - case .prepareStatement(_, _, let eventLoopPromise): - return .failPreparedStatementCreation(eventLoopPromise, with: .connectionQuiescing, cleanupContext: nil) - } + switch self.quiescingState { + case .quiescing: + psqlErrror = PSQLError.clientClosesConnection(underlying: nil) + + case .notQuiescing: + switch self.state { + case .initialized, + .authenticated, + .authenticating, + .closeCommand, + .extendedQuery, + .sslNegotiated, + .sslHandlerAdded, + .sslRequestSent, + .waitingToStartAuthentication: + self.taskQueue.append(task) + return .wait + + case .readyForQuery: + return self.executeTask(task) + + case .closing(let error): + psqlErrror = PSQLError.clientClosesConnection(underlying: error) + + case .closed(clientInitiated: true, error: let error): + psqlErrror = PSQLError.clientClosedConnection(underlying: error) - case .closeCommand(let closeContext): - return .failClose(closeContext, with: .connectionQuiescing, cleanupContext: nil) + case .closed(clientInitiated: false, error: let error): + psqlErrror = PSQLError.serverClosedConnection(underlying: error) + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") } } - switch self.state { - case .readyForQuery: - return self.executeTask(task) - case .closed: - switch task { - case .extendedQuery(let queryContext): - switch queryContext.query { - case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): - return .failQuery(eventLoopPromise, with: .connectionClosed, cleanupContext: nil) - case .prepareStatement(_, _, let eventLoopPromise): - return .failPreparedStatementCreation(eventLoopPromise, with: .connectionClosed, cleanupContext: nil) - } - case .closeCommand(let closeContext): - return .failClose(closeContext, with: .connectionClosed, cleanupContext: nil) + switch task { + case .extendedQuery(let queryContext): + switch queryContext.query { + case .executeStatement(_, let promise), .unnamed(_, let promise): + return .failQuery(promise, with: psqlErrror, cleanupContext: nil) + case .prepareStatement(_, _, let promise): + return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil) } - default: - self.taskQueue.append(task) - return .wait + case .closeCommand(let closeContext): + return .failClose(closeContext, with: psqlErrror, cleanupContext: nil) } } @@ -601,7 +607,6 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .closeCommand, - .error, .closing, .closed: return .wait @@ -648,8 +653,6 @@ struct ConnectionStateMachine { machine.state = .closeCommand(closeState, connectionContext) return machine.modify(with: action) } - case .error: - return .read case .closing: return .read case .closed: @@ -818,7 +821,7 @@ struct ConnectionStateMachine { } } - private mutating func closeConnectionAndCleanup(_ error: PSQLError) -> ConnectionAction { + private mutating func closeConnectionAndCleanup(_ error: PSQLError, closePromise: EventLoopPromise? = nil) -> ConnectionAction { switch self.state { case .initialized, .sslRequestSent, @@ -827,12 +830,12 @@ struct ConnectionStateMachine { .waitingToStartAuthentication, .authenticated, .readyForQuery: - let cleanupContext = self.setErrorAndCreateCleanupContext(error) + let cleanupContext = self.setErrorAndCreateCleanupContext(error, closePromise: closePromise) return .closeConnectionAndCleanup(cleanupContext) case .authenticating(var authState): - let cleanupContext = self.setErrorAndCreateCleanupContext(error) - + let cleanupContext = self.setErrorAndCreateCleanupContext(error, closePromise: closePromise) + if authState.isComplete { // in case the auth state machine is complete all necessary actions have already // been forwarded to the consumer. We can close and cleanup without caring about the @@ -847,8 +850,8 @@ struct ConnectionStateMachine { return .closeConnectionAndCleanup(cleanupContext) case .extendedQuery(var queryStateMachine, _): - let cleanupContext = self.setErrorAndCreateCleanupContext(error) - + let cleanupContext = self.setErrorAndCreateCleanupContext(error, closePromise: closePromise) + if queryStateMachine.isComplete { // in case the query state machine is complete all necessary actions have already // been forwarded to the consumer. We can close and cleanup without caring about the @@ -867,19 +870,23 @@ struct ConnectionStateMachine { .wait, .read: preconditionFailure("Invalid state: \(self.state)") + case .evaluateErrorAtConnectionLevel: return .closeConnectionAndCleanup(cleanupContext) + case .failQuery(let queryContext, with: let error): return .failQuery(queryContext, with: error, cleanupContext: cleanupContext) + case .forwardStreamError(let error, let read): return .forwardStreamError(error, read: read, cleanupContext: cleanupContext) + case .failPreparedStatementCreation(let promise, with: let error): return .failPreparedStatementCreation(promise, with: error, cleanupContext: cleanupContext) } case .closeCommand(var closeStateMachine, _): - let cleanupContext = self.setErrorAndCreateCleanupContext(error) - + let cleanupContext = self.setErrorAndCreateCleanupContext(error, closePromise: closePromise) + if closeStateMachine.isComplete { // in case the close state machine is complete all necessary actions have already // been forwarded to the consumer. We can close and cleanup without caring about the @@ -897,7 +904,7 @@ struct ConnectionStateMachine { return .failClose(closeCommandContext, with: error, cleanupContext: cleanupContext) } - case .error, .closing, .closed: + case .closing, .closed: // We might run into this case because of reentrancy. For example: After we received an // backend unexpected message, that we read of the wire, we bring this connection into // the error state and will try to close the connection. However the server might have @@ -921,7 +928,7 @@ struct ConnectionStateMachine { // if we don't have anything left to do and we are quiescing, next we should close if case .quiescing(let promise) = self.quiescingState { - self.state = .closing + self.state = .closing(nil) return .closeConnection(promise) } @@ -1024,9 +1031,9 @@ extension ConnectionStateMachine { } return false - case .connectionQuiescing: + case .clientClosesConnection, .clientClosedConnection: preconditionFailure("Pure client error, that is thrown directly in PostgresConnection") - case .connectionClosed: + case .serverClosedConnection: preconditionFailure("Pure client error, that is thrown directly and should never ") } } @@ -1039,23 +1046,28 @@ extension ConnectionStateMachine { return self.setErrorAndCreateCleanupContext(error) } - mutating func setErrorAndCreateCleanupContext(_ error: PSQLError) -> ConnectionAction.CleanUpContext { + mutating func setErrorAndCreateCleanupContext(_ error: PSQLError, closePromise: EventLoopPromise? = nil) -> ConnectionAction.CleanUpContext { let tasks = Array(self.taskQueue) self.taskQueue.removeAll() - var closePromise: EventLoopPromise? = nil - if case .quiescing(let promise) = self.quiescingState { - closePromise = promise + var forwardedPromise: EventLoopPromise? = nil + if case .quiescing(.some(let quiescePromise)) = self.quiescingState, let closePromise = closePromise { + quiescePromise.futureResult.cascade(to: closePromise) + forwardedPromise = quiescePromise + } else if case .quiescing(.some(let quiescePromise)) = self.quiescingState { + forwardedPromise = quiescePromise + } else { + forwardedPromise = closePromise } - - self.state = .error(error) - + + self.state = .closing(error) + var action = ConnectionAction.CleanUpContext.Action.close if case .uncleanShutdown = error.code.base { action = .fireChannelInactive } - return .init(action: action, tasks: tasks, error: error, closePromise: closePromise) + return .init(action: action, tasks: tasks, error: error, closePromise: forwardedPromise) } } @@ -1187,8 +1199,6 @@ extension ConnectionStateMachine.State: CustomDebugStringConvertible { return ".extendedQuery(\(String(reflecting: subStateMachine)), connectionContext: \(String(reflecting: connectionContext)))" case .closeCommand(let subStateMachine, let connectionContext): return ".closeCommand(\(String(reflecting: subStateMachine)), connectionContext: \(String(reflecting: connectionContext)))" - case .error(let error): - return ".error(\(String(reflecting: error)))" case .closing: return ".closing" case .closed: diff --git a/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift index c7f92428..89f40469 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift @@ -36,7 +36,14 @@ struct ListenStateMachine { } mutating func stopListeningSucceeded(channel: String) -> StopListeningSuccessAction { - return self.channels[channel, default: .init()].stopListeningSucceeded() + switch self.channels[channel]!.stopListeningSucceeded() { + case .none: + self.channels.removeValue(forKey: channel) + return .none + + case .startListening: + return .startListening + } } enum CancelAction { @@ -46,7 +53,7 @@ struct ListenStateMachine { } mutating func cancelNotificationListener(channel: String, id: Int) -> CancelAction { - return self.channels[channel, default: .init()].cancelListening(id: id) + return self.channels[channel]?.cancelListening(id: id) ?? .none } mutating func fail(_ error: Error) -> [NotificationListener] { diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index 5d9e534c..1fec59b1 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -18,8 +18,9 @@ public struct PSQLError: Error { case queryCancelled case tooManyParameters - case connectionQuiescing - case connectionClosed + case clientClosesConnection + case clientClosedConnection + case serverClosedConnection case connectionError case uncleanShutdown @@ -45,13 +46,20 @@ public struct PSQLError: Error { public static let invalidCommandTag = Self(.invalidCommandTag) public static let queryCancelled = Self(.queryCancelled) public static let tooManyParameters = Self(.tooManyParameters) - public static let connectionQuiescing = Self(.connectionQuiescing) - public static let connectionClosed = Self(.connectionClosed) + public static let clientClosesConnection = Self(.clientClosesConnection) + public static let clientClosedConnection = Self(.clientClosedConnection) + public static let serverClosedConnection = Self(.serverClosedConnection) public static let connectionError = Self(.connectionError) public static let uncleanShutdown = Self.init(.uncleanShutdown) public static let listenFailed = Self.init(.listenFailed) public static let unlistenFailed = Self.init(.unlistenFailed) + @available(*, deprecated, renamed: "clientClosesConnection") + public static let connectionQuiescing = Self.clientClosesConnection + + @available(*, deprecated, message: "Use the more specific `serverClosedConnection` or `clientClosedConnection` instead") + public static let connectionClosed = Self.serverClosedConnection + public var description: String { switch self.base { case .sslUnsupported: @@ -78,10 +86,12 @@ public struct PSQLError: Error { return "queryCancelled" case .tooManyParameters: return "tooManyParameters" - case .connectionQuiescing: - return "connectionQuiescing" - case .connectionClosed: - return "connectionClosed" + case .clientClosesConnection: + return "clientClosesConnection" + case .clientClosedConnection: + return "clientClosedConnection" + case .serverClosedConnection: + return "serverClosedConnection" case .connectionError: return "connectionError" case .uncleanShutdown: @@ -377,19 +387,33 @@ public struct PSQLError: Error { return new } - static var connectionQuiescing: PSQLError { PSQLError(code: .connectionQuiescing) } + static func clientClosesConnection(underlying: Error?) -> PSQLError { + var error = PSQLError(code: .clientClosesConnection) + error.underlying = underlying + return error + } + + static func clientClosedConnection(underlying: Error?) -> PSQLError { + var error = PSQLError(code: .clientClosedConnection) + error.underlying = underlying + return error + } - static var connectionClosed: PSQLError { PSQLError(code: .connectionClosed) } + static func serverClosedConnection(underlying: Error?) -> PSQLError { + var error = PSQLError(code: .serverClosedConnection) + error.underlying = underlying + return error + } - static var authMechanismRequiresPassword: PSQLError { PSQLError(code: .authMechanismRequiresPassword) } + static let authMechanismRequiresPassword = PSQLError(code: .authMechanismRequiresPassword) - static var sslUnsupported: PSQLError { PSQLError(code: .sslUnsupported) } + static let sslUnsupported = PSQLError(code: .sslUnsupported) - static var queryCancelled: PSQLError { PSQLError(code: .queryCancelled) } + static let queryCancelled = PSQLError(code: .queryCancelled) - static var uncleanShutdown: PSQLError { PSQLError(code: .uncleanShutdown) } + static let uncleanShutdown = PSQLError(code: .uncleanShutdown) - static var receivedUnencryptedDataAfterSSLRequest: PSQLError { PSQLError(code: .receivedUnencryptedDataAfterSSLRequest) } + static let receivedUnencryptedDataAfterSSLRequest = PSQLError(code: .receivedUnencryptedDataAfterSSLRequest) static func server(_ response: PostgresBackendMessage.ErrorResponse) -> PSQLError { var error = PSQLError(code: .server) diff --git a/Sources/PostgresNIO/New/PSQLEventsHandler.swift b/Sources/PostgresNIO/New/PSQLEventsHandler.swift index 3233fb77..2bf0d6d8 100644 --- a/Sources/PostgresNIO/New/PSQLEventsHandler.swift +++ b/Sources/PostgresNIO/New/PSQLEventsHandler.swift @@ -7,6 +7,8 @@ enum PSQLOutgoingEvent { /// /// this shall be removed with the next breaking change and always supplied with `PSQLConnection.Configuration` case authenticate(AuthContext) + + case gracefulShutdown } enum PSQLEvent { diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index abfa5aeb..7801d4d6 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -247,7 +247,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { return } - let action = self.state.close(promise) + let action = self.state.close(promise: promise) self.run(action, with: context) } @@ -258,6 +258,11 @@ final class PostgresChannelHandler: ChannelDuplexHandler { case PSQLOutgoingEvent.authenticate(let authContext): let action = self.state.provideAuthenticationContext(authContext) self.run(action, with: context) + + case PSQLOutgoingEvent.gracefulShutdown: + let action = self.state.gracefulClose(promise) + self.run(action, with: context) + default: context.triggerUserOutboundEvent(event, promise: promise) } diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index 10970b26..1989e5bc 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -37,9 +37,9 @@ extension PSQLError { return self.underlying ?? self case .tooManyParameters, .invalidCommandTag: return self - case .connectionQuiescing: - return PostgresError.connectionClosed - case .connectionClosed: + case .clientClosesConnection, + .clientClosedConnection, + .serverClosedConnection: return PostgresError.connectionClosed case .connectionError: return self.underlying ?? self diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index 5fd3bc20..f3d72a5e 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -137,14 +137,14 @@ class ConnectionStateMachineTests: XCTestCase { func testErrorIsIgnoredWhenClosingConnection() { // test ignore unclean shutdown when closing connection - var stateIgnoreChannelError = ConnectionStateMachine(.closing) - + var stateIgnoreChannelError = ConnectionStateMachine(.closing(nil)) + XCTAssertEqual(stateIgnoreChannelError.errorHappened(.connectionError(underlying: NIOSSLError.uncleanShutdown)), .wait) XCTAssertEqual(stateIgnoreChannelError.closed(), .fireChannelInactive) // test ignore any other error when closing connection - var stateIgnoreErrorMessage = ConnectionStateMachine(.closing) + var stateIgnoreErrorMessage = ConnectionStateMachine(.closing(nil)) XCTAssertEqual(stateIgnoreErrorMessage.errorReceived(.init(fields: [:])), .wait) XCTAssertEqual(stateIgnoreErrorMessage.closed(), .fireChannelInactive) } diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift index 1af35fac..d6d03107 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -22,13 +22,13 @@ final class PSQLRowStreamTests: XCTestCase { func testFailedStream() { let stream = PSQLRowStream( - source: .noRows(.failure(PSQLError.connectionClosed)), + source: .noRows(.failure(PSQLError.serverClosedConnection(underlying: nil))), eventLoop: self.eventLoop, logger: self.logger ) XCTAssertThrowsError(try stream.all().wait()) { - XCTAssertEqual($0 as? PSQLError, .connectionClosed) + XCTAssertEqual($0 as? PSQLError, .serverClosedConnection(underlying: nil)) } } diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index 5388e8b5..eed5ada7 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -24,8 +24,11 @@ class PostgresChannelHandlerTests: XCTestCase { ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), handler ], loop: self.eventLoop) - defer { XCTAssertNoThrow(try embedded.finish()) } - + defer { + do { try embedded.finish() } + catch { print("\(String(reflecting: error))") } + } + var maybeMessage: PostgresFrontendMessage? XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil)) XCTAssertNoThrow(maybeMessage = try embedded.readOutbound(as: PostgresFrontendMessage.self)) diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 0622d51e..46f864ce 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -182,6 +182,98 @@ class PostgresConnectionTests: XCTestCase { } } + func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + for _ in 1...2 { + taskGroup.addTask { + let rows = try await connection.query("SELECT 1;", logger: self.logger) + var iterator = rows.decode(Int.self).makeAsyncIterator() + let first = try await iterator.next() + XCTAssertEqual(first, 1) + let second = try await iterator.next() + XCTAssertNil(second) + } + } + + for i in 0...1 { + let listenMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") + + if i == 0 { + taskGroup.addTask { + try await connection.closeGracefully() + } + } + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + let intDescription = RowDescription.Column( + name: "", + tableOID: 0, + columnAttributeNumber: 0, + dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary + ) + try await channel.writeInbound(PostgresBackendMessage.rowDescription(.init(columns: [intDescription]))) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.dataRow([Int(1)])) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.commandComplete("SELECT 1 1")) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + } + + let terminate = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + XCTAssertEqual(terminate, .terminate) + try await channel.closeFuture.get() + XCTAssertEqual(channel.isActive, false) + + while let taskResult = await taskGroup.nextResult() { + switch taskResult { + case .success: + break + case .failure(let failure): + XCTFail("Unexpected error: \(failure)") + } + } + } + } + + func testCloseClosesImmediatly() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + for _ in 1...2 { + taskGroup.addTask { + try await connection.query("SELECT 1;", logger: self.logger) + } + } + + let listenMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") + + async let close: () = connection.close() + + try await channel.closeFuture.get() + XCTAssertEqual(channel.isActive, false) + + try await close + + while let taskResult = await taskGroup.nextResult() { + switch taskResult { + case .success: + XCTFail("Expected queries to fail") + case .failure(let failure): + guard let error = failure as? PSQLError else { + return XCTFail("Unexpected error type: \(failure)") + } + XCTAssertEqual(error.code, .clientClosedConnection) + } + } + } + } func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index fc589c0b..872c098d 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -183,7 +183,7 @@ final class PostgresRowSequenceTests: XCTestCase { logger: self.logger ) - stream.receive(completion: .failure(PSQLError.connectionClosed)) + stream.receive(completion: .failure(PSQLError.serverClosedConnection(underlying: nil))) let rowSequence = stream.asyncSequence() @@ -194,7 +194,7 @@ final class PostgresRowSequenceTests: XCTestCase { } XCTFail("Expected that an error was thrown before.") } catch { - XCTAssertEqual(error as? PSQLError, .connectionClosed) + XCTAssertEqual(error as? PSQLError, .serverClosedConnection(underlying: nil)) } } @@ -255,14 +255,14 @@ final class PostgresRowSequenceTests: XCTestCase { XCTAssertEqual(try row1?.decode(Int.self, context: .default), 0) DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { - stream.receive(completion: .failure(PSQLError.connectionClosed)) + stream.receive(completion: .failure(PSQLError.serverClosedConnection(underlying: nil))) } do { _ = try await rowIterator.next() XCTFail("Expected that an error was thrown before.") } catch { - XCTAssertEqual(error as? PSQLError, .connectionClosed) + XCTAssertEqual(error as? PSQLError, .serverClosedConnection(underlying: nil)) } } From 5217ba7557f8aa292fcf5f0440bfc2bed7862efb Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Thu, 10 Aug 2023 06:16:17 -0500 Subject: [PATCH 176/292] Use README header image compatible with light/dark mode (#393) --- README.md | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 441a41e3..b4f8f70e 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,18 @@ -PostgresNIO - -[![SSWG Incubating Badge](https://img.shields.io/badge/sswg-incubating-green.svg)][SSWG Incubation] -[![Documentation](http://img.shields.io/badge/read_the-docs-2196f3.svg)][Documentation] -[![Team Chat](https://img.shields.io/discord/431917998102675485.svg)][Team Chat] -[![MIT License](http://img.shields.io/badge/license-MIT-brightgreen.svg)][MIT License] -[![Continuous Integration](https://github.com/vapor/postgres-nio/actions/workflows/test.yml/badge.svg)][Continuous Integration] -[![Swift 5.6](http://img.shields.io/badge/swift-5.6-brightgreen.svg)][Swift 5.6] +

+ + + + PostgresNIO +

- +SSWG Incubation +Documentation +MIT License +Continuous Integration +Swift 5.6 +

+
🐘 Non-blocking, event-driven Swift client for PostgreSQL built on [SwiftNIO]. Features: From d5c52584cb3f19b3166040e05271f7581b0befa3 Mon Sep 17 00:00:00 2001 From: Mario Sangiorgio Date: Fri, 18 Aug 2023 11:12:18 +0100 Subject: [PATCH 177/292] async/await prepared statement API (#390) This patch adds a new `PreparedStatement` protocol to represent prepared statements and an `execute` function on `PostgresConnection` to prepare and execute statements. To implement the features the patch also introduces a `PreparedStatementStateMachine` that keeps track of the state of a prepared statement at the connection level. This ensures that, for each connection, each statement is prepared once at time of first use and then subsequent uses are going to skip the preparation step and just execute it. ## Example usage First define the struct to represent the prepared statement: ```swift struct ExamplePreparedStatement: PreparedStatement { static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" typealias Row = (Int, String) var state: String func makeBindings() -> PostgresBindings { var bindings = PostgresBindings() bindings.append(self.state) return bindings } func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { try row.decode(Row.self) } } ``` then, assuming you already have a `PostgresConnection` you can execute it: ```swift let preparedStatement = ExamplePreparedStatement(state: "active") let results = try await connection.execute(preparedStatement, logger: logger) for (pid, database) in results { print("PID: \(pid), database: \(database)") } ``` --------- Co-authored-by: Fabian Fett --- .../Connection/PostgresConnection.swift | 66 ++++ .../PreparedStatementStateMachine.swift | 93 +++++ Sources/PostgresNIO/New/PSQLTask.swift | 23 ++ .../New/PostgresChannelHandler.swift | 115 +++++- Sources/PostgresNIO/New/PostgresQuery.swift | 10 + .../PostgresNIO/New/PreparedStatement.swift | 40 ++ Tests/IntegrationTests/AsyncTests.swift | 42 +++ .../PreparedStatementStateMachineTests.swift | 159 ++++++++ .../PSQLFrontendMessageDecoder.swift | 2 +- .../New/PostgresConnectionTests.swift | 352 ++++++++++++++++++ 10 files changed, 898 insertions(+), 4 deletions(-) create mode 100644 Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift create mode 100644 Sources/PostgresNIO/New/PreparedStatement.swift create mode 100644 Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 7ac8ec57..d3f51ca9 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -460,6 +460,72 @@ extension PostgresConnection { self.channel.write(task, promise: nil) } } + + /// Execute a prepared statement, taking care of the preparation when necessary + public func execute( + _ preparedStatement: Statement, + logger: Logger, + file: String = #fileID, + line: Int = #line + ) async throws -> AsyncThrowingMapSequence where Row == Statement.Row { + let bindings = try preparedStatement.makeBindings() + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let task = HandlerTask.executePreparedStatement(.init( + name: String(reflecting: Statement.self), + sql: Statement.sql, + bindings: bindings, + logger: logger, + promise: promise + )) + self.channel.write(task, promise: nil) + do { + return try await promise.futureResult + .map { $0.asyncSequence() } + .get() + .map { try preparedStatement.decodeRow($0) } + } catch var error as PSQLError { + error.file = file + error.line = line + error.query = .init( + unsafeSQL: Statement.sql, + binds: bindings + ) + throw error // rethrow with more metadata + } + + } + + /// Execute a prepared statement, taking care of the preparation when necessary + public func execute( + _ preparedStatement: Statement, + logger: Logger, + file: String = #fileID, + line: Int = #line + ) async throws -> String where Statement.Row == () { + let bindings = try preparedStatement.makeBindings() + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let task = HandlerTask.executePreparedStatement(.init( + name: String(reflecting: Statement.self), + sql: Statement.sql, + bindings: bindings, + logger: logger, + promise: promise + )) + self.channel.write(task, promise: nil) + do { + return try await promise.futureResult + .map { $0.commandTag } + .get() + } catch var error as PSQLError { + error.file = file + error.line = line + error.query = .init( + unsafeSQL: Statement.sql, + binds: bindings + ) + throw error // rethrow with more metadata + } + } } // MARK: EventLoopFuture interface diff --git a/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift new file mode 100644 index 00000000..5afa4d0b --- /dev/null +++ b/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift @@ -0,0 +1,93 @@ +import NIOCore + +struct PreparedStatementStateMachine { + enum State { + case preparing([PreparedStatementContext]) + case prepared(RowDescription?) + case error(PSQLError) + } + + var preparedStatements: [String: State] = [:] + + enum LookupAction { + case prepareStatement + case waitForAlreadyInFlightPreparation + case executeStatement(RowDescription?) + case returnError(PSQLError) + } + + mutating func lookup(preparedStatement: PreparedStatementContext) -> LookupAction { + if let state = self.preparedStatements[preparedStatement.name] { + switch state { + case .preparing(var statements): + statements.append(preparedStatement) + self.preparedStatements[preparedStatement.name] = .preparing(statements) + return .waitForAlreadyInFlightPreparation + case .prepared(let rowDescription): + return .executeStatement(rowDescription) + case .error(let error): + return .returnError(error) + } + } else { + self.preparedStatements[preparedStatement.name] = .preparing([preparedStatement]) + return .prepareStatement + } + } + + struct PreparationCompleteAction { + var statements: [PreparedStatementContext] + var rowDescription: RowDescription? + } + + mutating func preparationComplete( + name: String, + rowDescription: RowDescription? + ) -> PreparationCompleteAction { + guard let state = self.preparedStatements[name] else { + fatalError("Unknown prepared statement \(name)") + } + switch state { + case .preparing(let statements): + // When sending the bindings we are going to ask for binary data. + if var rowDescription = rowDescription { + for i in 0.. ErrorHappenedAction { + guard let state = self.preparedStatements[name] else { + fatalError("Unknown prepared statement \(name)") + } + switch state { + case .preparing(let statements): + self.preparedStatements[name] = .error(error) + return ErrorHappenedAction( + statements: statements, + error: error + ) + case .prepared, .error: + preconditionFailure("Error happened in an unexpected state \(state)") + } + } +} diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index f5de6561..9425c12b 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -6,6 +6,7 @@ enum HandlerTask { case closeCommand(CloseCommandContext) case startListening(NotificationListener) case cancelListening(String, Int) + case executePreparedStatement(PreparedStatementContext) } enum PSQLTask { @@ -69,6 +70,28 @@ final class ExtendedQueryContext { } } +final class PreparedStatementContext{ + let name: String + let sql: String + let bindings: PostgresBindings + let logger: Logger + let promise: EventLoopPromise + + init( + name: String, + sql: String, + bindings: PostgresBindings, + logger: Logger, + promise: EventLoopPromise + ) { + self.name = name + self.sql = sql + self.bindings = bindings + self.logger = logger + self.promise = promise + } +} + final class CloseCommandContext { let target: CloseTarget let logger: Logger diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 7801d4d6..bf56d6d1 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -22,7 +22,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private let configuration: PostgresConnection.InternalConfiguration private let configureSSLCallback: ((Channel) throws -> Void)? - private var listenState: ListenStateMachine + private var listenState = ListenStateMachine() + private var preparedStatementState = PreparedStatementStateMachine() init( configuration: PostgresConnection.InternalConfiguration, @@ -32,7 +33,6 @@ final class PostgresChannelHandler: ChannelDuplexHandler { ) { self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData) self.eventLoop = eventLoop - self.listenState = ListenStateMachine() self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger @@ -50,7 +50,6 @@ final class PostgresChannelHandler: ChannelDuplexHandler { ) { self.state = state self.eventLoop = eventLoop - self.listenState = ListenStateMachine() self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger @@ -233,6 +232,29 @@ final class PostgresChannelHandler: ChannelDuplexHandler { listener.failed(CancellationError()) return } + case .executePreparedStatement(let preparedStatement): + let action = self.preparedStatementState.lookup( + preparedStatement: preparedStatement + ) + switch action { + case .prepareStatement: + psqlTask = self.makePrepareStatementTask( + preparedStatement: preparedStatement, + context: context + ) + case .waitForAlreadyInFlightPreparation: + // The state machine already keeps track of this + // and will execute the statement as soon as it's prepared + return + case .executeStatement(let rowDescription): + psqlTask = self.makeExecutePreparedStatementTask( + preparedStatement: preparedStatement, + rowDescription: rowDescription + ) + case .returnError(let error): + preparedStatement.promise.fail(error) + return + } } let action = self.state.enqueue(task: psqlTask) @@ -664,6 +686,93 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } } + private func makePrepareStatementTask( + preparedStatement: PreparedStatementContext, + context: ChannelHandlerContext + ) -> PSQLTask { + let promise = self.eventLoop.makePromise(of: RowDescription?.self) + promise.futureResult.whenComplete { result in + switch result { + case .success(let rowDescription): + self.prepareStatementComplete( + name: preparedStatement.name, + rowDescription: rowDescription, + context: context + ) + case .failure(let error): + let psqlError: PSQLError + if let error = error as? PSQLError { + psqlError = error + } else { + psqlError = .connectionError(underlying: error) + } + self.prepareStatementFailed( + name: preparedStatement.name, + error: psqlError, + context: context + ) + } + } + return .extendedQuery(.init( + name: preparedStatement.name, + query: preparedStatement.sql, + logger: preparedStatement.logger, + promise: promise + )) + } + + private func makeExecutePreparedStatementTask( + preparedStatement: PreparedStatementContext, + rowDescription: RowDescription? + ) -> PSQLTask { + return .extendedQuery(.init( + executeStatement: .init( + name: preparedStatement.name, + binds: preparedStatement.bindings, + rowDescription: rowDescription + ), + logger: preparedStatement.logger, + promise: preparedStatement.promise + )) + } + + private func prepareStatementComplete( + name: String, + rowDescription: RowDescription?, + context: ChannelHandlerContext + ) { + let action = self.preparedStatementState.preparationComplete( + name: name, + rowDescription: rowDescription + ) + for preparedStatement in action.statements { + let action = self.state.enqueue(task: .extendedQuery(.init( + executeStatement: .init( + name: preparedStatement.name, + binds: preparedStatement.bindings, + rowDescription: action.rowDescription + ), + logger: preparedStatement.logger, + promise: preparedStatement.promise + )) + ) + self.run(action, with: context) + } + } + + private func prepareStatementFailed( + name: String, + error: PSQLError, + context: ChannelHandlerContext + ) { + let action = self.preparedStatementState.errorHappened( + name: name, + error: error + ) + for statement in action.statements { + statement.promise.fail(action.error) + } + } } extension PostgresChannelHandler: PSQLRowsDataSource { diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 2e06e1d9..4ca1e454 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -167,6 +167,11 @@ public struct PostgresBindings: Sendable, Hashable { self.metadata.append(.init(dataType: .null, format: .binary, protected: true)) } + @inlinable + public mutating func append(_ value: Value) throws { + try self.append(value, context: .default) + } + @inlinable public mutating func append( _ value: Value, @@ -176,6 +181,11 @@ public struct PostgresBindings: Sendable, Hashable { self.metadata.append(.init(value: value, protected: true)) } + @inlinable + public mutating func append(_ value: Value) { + self.append(value, context: .default) + } + @inlinable public mutating func append( _ value: Value, diff --git a/Sources/PostgresNIO/New/PreparedStatement.swift b/Sources/PostgresNIO/New/PreparedStatement.swift new file mode 100644 index 00000000..1e0b5d5a --- /dev/null +++ b/Sources/PostgresNIO/New/PreparedStatement.swift @@ -0,0 +1,40 @@ +/// A prepared statement. +/// +/// Structs conforming to this protocol will need to provide the SQL statement to +/// send to the server and a way of creating bindings are decoding the result. +/// +/// As an example, consider this struct: +/// ```swift +/// struct Example: PostgresPreparedStatement { +/// static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" +/// typealias Row = (Int, String) +/// +/// var state: String +/// +/// func makeBindings() -> PostgresBindings { +/// var bindings = PostgresBindings() +/// bindings.append(self.state) +/// return bindings +/// } +/// +/// func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { +/// try row.decode(Row.self) +/// } +/// } +/// ``` +/// +/// Structs conforming to this protocol can then be used with `PostgresConnection.execute(_ preparedStatement:, logger:)`, +/// which will take care of preparing the statement on the server side and executing it. +public protocol PostgresPreparedStatement: Sendable { + /// The type rows returned by the statement will be decoded into + associatedtype Row + + /// The SQL statement to prepare on the database server. + static var sql: String { get } + + /// Make the bindings to provided concrete values to use when executing the prepared SQL statement + func makeBindings() throws -> PostgresBindings + + /// Decode a row returned by the database into an instance of `Row` + func decodeRow(_ row: PostgresRow) throws -> Row +} diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index f68ef1f3..bf945a67 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -315,6 +315,48 @@ final class AsyncPostgresConnectionTests: XCTestCase { try await connection.query("SELECT 1;", logger: .psqlTest) } } + + func testPreparedStatement() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + struct TestPreparedStatement: PostgresPreparedStatement { + static var sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" + typealias Row = (Int, String) + + var state: String + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.state) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode(Row.self) + } + } + let preparedStatement = TestPreparedStatement(state: "active") + try await withTestConnection(on: eventLoop) { connection in + var results = try await connection.execute(preparedStatement, logger: .psqlTest) + var counter = 0 + + for try await element in results { + XCTAssertEqual(element.1, env("POSTGRES_DB") ?? "test_database") + counter += 1 + } + + XCTAssertGreaterThanOrEqual(counter, 1) + + // Second execution, which reuses the existing prepared statement + results = try await connection.execute(preparedStatement, logger: .psqlTest) + for try await element in results { + XCTAssertEqual(element.1, env("POSTGRES_DB") ?? "test_database") + counter += 1 + } + } + } } extension XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift new file mode 100644 index 00000000..ab77a57c --- /dev/null +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift @@ -0,0 +1,159 @@ +import XCTest +import NIOEmbedded +@testable import PostgresNIO + +class PreparedStatementStateMachineTests: XCTestCase { + func testPrepareAndExecuteStatement() { + let eventLoop = EmbeddedEventLoop() + var stateMachine = PreparedStatementStateMachine() + + let firstPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Initial lookup, the statement hasn't been prepared yet + let lookupAction = stateMachine.lookup(preparedStatement: firstPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .prepareStatement = lookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // Once preparation is complete we transition to a prepared state + let preparationCompleteAction = stateMachine.preparationComplete(name: "test", rowDescription: nil) + guard case .prepared(nil) = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + XCTAssertEqual(preparationCompleteAction.statements.count, 1) + XCTAssertNil(preparationCompleteAction.rowDescription) + firstPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success("tag")), + eventLoop: eventLoop, + logger: .psqlTest + )) + + // Create a new prepared statement + let secondPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // The statement is already preparead, lookups tell us to execute it + let secondLookupAction = stateMachine.lookup(preparedStatement: secondPreparedStatement) + guard case .prepared(nil) = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .executeStatement(nil) = secondLookupAction else { + XCTFail("State machine returned the wrong action") + return + } + secondPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success("tag")), + eventLoop: eventLoop, + logger: .psqlTest + )) + } + + func testPrepareAndExecuteStatementWithError() { + let eventLoop = EmbeddedEventLoop() + var stateMachine = PreparedStatementStateMachine() + + let firstPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Initial lookup, the statement hasn't been prepared yet + let lookupAction = stateMachine.lookup(preparedStatement: firstPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .prepareStatement = lookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // Simulate an error occurring during preparation + let error = PSQLError(code: .server) + let preparationCompleteAction = stateMachine.errorHappened( + name: "test", + error: error + ) + guard case .error = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + XCTAssertEqual(preparationCompleteAction.statements.count, 1) + firstPreparedStatement.promise.fail(error) + + // Create a new prepared statement + let secondPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Ensure that we don't try again to prepare a statement we know will fail + let secondLookupAction = stateMachine.lookup(preparedStatement: secondPreparedStatement) + guard case .error = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .returnError = secondLookupAction else { + XCTFail("State machine returned the wrong action") + return + } + secondPreparedStatement.promise.fail(error) + } + + func testBatchStatementPreparation() { + let eventLoop = EmbeddedEventLoop() + var stateMachine = PreparedStatementStateMachine() + + let firstPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Initial lookup, the statement hasn't been prepared yet + let lookupAction = stateMachine.lookup(preparedStatement: firstPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .prepareStatement = lookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // A new request comes in before the statement completes + let secondPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + let secondLookupAction = stateMachine.lookup(preparedStatement: secondPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .waitForAlreadyInFlightPreparation = secondLookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // Once preparation is complete we transition to a prepared state. + // The action tells us to execute both the pending statements. + let preparationCompleteAction = stateMachine.preparationComplete(name: "test", rowDescription: nil) + guard case .prepared(nil) = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + XCTAssertEqual(preparationCompleteAction.statements.count, 2) + XCTAssertNil(preparationCompleteAction.rowDescription) + + firstPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success("tag")), + eventLoop: eventLoop, + logger: .psqlTest + )) + secondPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success("tag")), + eventLoop: eventLoop, + logger: .psqlTest + )) + } + + private func makePreparedStatementContext(eventLoop: EmbeddedEventLoop) -> PreparedStatementContext { + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + return PreparedStatementContext( + name: "test", + sql: "INSERT INTO test_table (column1) VALUES (1)", + bindings: PostgresBindings(), + logger: .psqlTest, + promise: promise + ) + } +} diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index b9677000..46c043b1 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -142,7 +142,7 @@ extension PostgresFrontendMessage { } let parameters = (0.. ByteBuffer? in - let length = buffer.readInteger(as: UInt16.self) + let length = buffer.readInteger(as: UInt32.self) switch length { case .some(..<0): return nil diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 46f864ce..9c4dc5cb 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -275,6 +275,288 @@ class PostgresConnectionTests: XCTestCase { } } + struct TestPrepareStatement: PostgresPreparedStatement { + static var sql = "SELECT datname FROM pg_stat_activity WHERE state = $1" + typealias Row = String + + var state: String + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(.init(string: self.state)) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode(Row.self) + } + } + + func testPreparedStatement() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database", database) + } + XCTAssertEqual(rows, 1) + } + + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) + + let preparedRequest = try await channel.waitForPreparedRequest() + XCTAssertEqual(preparedRequest.bind.preparedStatementName, String(reflecting: TestPrepareStatement.self)) + XCTAssertEqual(preparedRequest.bind.parameters.count, 1) + XCTAssertEqual(preparedRequest.bind.resultColumnFormats, [.binary]) + + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) + } + } + + func testSerialExecutionOfSamePreparedStatement() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Send one first and wait to send the other request until preparation is complete + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database", database) + } + XCTAssertEqual(rows, 1) + } + + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) + + let preparedRequest1 = try await channel.waitForPreparedRequest() + var buffer = preparedRequest1.bind.parameters[0]! + let parameter1 = buffer.readString(length: buffer.readableBytes)! + XCTAssertEqual(parameter1, "active") + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) + + // Now that the statement has been prepared and executed, send another request that will only get executed + // without preparation + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database", database) + } + XCTAssertEqual(rows, 1) + } + + let preparedRequest2 = try await channel.waitForPreparedRequest() + buffer = preparedRequest2.bind.parameters[0]! + let parameter2 = buffer.readString(length: buffer.readableBytes)! + XCTAssertEqual(parameter2, "idle") + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) + // Ensure we received and responded to both the requests + let parameters = [parameter1, parameter2] + XCTAssert(parameters.contains("active")) + XCTAssert(parameters.contains("idle")) + } + } + + func testStatementPreparationOnlyHappensOnceWithConcurrentRequests() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Let them race to tests that requests and responses aren't mixed up + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database_active", database) + } + XCTAssertEqual(rows, 1) + } + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database_idle", database) + } + XCTAssertEqual(rows, 1) + } + + // The channel deduplicates prepare requests, we're going to see only one of them + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) + + // Now both the tasks have their statements prepared. + // We should see both of their execute requests coming in, the order is nondeterministic + let preparedRequest1 = try await channel.waitForPreparedRequest() + var buffer = preparedRequest1.bind.parameters[0]! + let parameter1 = buffer.readString(length: buffer.readableBytes)! + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database_\(parameter1)"] + ], + commandTag: TestPrepareStatement.sql + ) + let preparedRequest2 = try await channel.waitForPreparedRequest() + buffer = preparedRequest2.bind.parameters[0]! + let parameter2 = buffer.readString(length: buffer.readableBytes)! + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database_\(parameter2)"] + ], + commandTag: TestPrepareStatement.sql + ) + // Ensure we received and responded to both the requests + let parameters = [parameter1, parameter2] + XCTAssert(parameters.contains("active")) + XCTAssert(parameters.contains("idle")) + } + } + + func testStatementPreparationFailure() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Send one first and wait to send the other request until preparation is complete + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + do { + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + XCTFail("Was supposed to fail") + } catch { + XCTAssert(error is PSQLError) + } + } + + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + // Respond with an error taking care to return a SQLSTATE that isn't + // going to get the connection closed. + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .sqlState : "26000" // invalid_sql_statement_name + ]))) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await channel.testingEventLoop.executeInContext { channel.read() } + + + // Send another requests with the same prepared statement, which should fail straight + // away without any interaction with the server + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + do { + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + XCTFail("Was supposed to fail") + } catch { + XCTAssert(error is PSQLError) + } + } + } + } + func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() let channel = await NIOAsyncTestingChannel(handlers: [ @@ -327,6 +609,66 @@ extension NIOAsyncTestingChannel { return UnpreparedRequest(parse: parse, describe: describe, bind: bind, execute: execute) } + + func waitForPrepareRequest() async throws -> PrepareRequest { + let parse = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let describe = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + + guard case .parse(let parse) = parse, + case .describe(let describe) = describe, + case .sync = sync + else { + fatalError("Unexpected message") + } + + return PrepareRequest(parse: parse, describe: describe) + } + + func sendPrepareResponse( + parameterDescription: PostgresBackendMessage.ParameterDescription, + rowDescription: RowDescription + ) async throws { + try await self.writeInbound(PostgresBackendMessage.parseComplete) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.parameterDescription(parameterDescription)) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.rowDescription(rowDescription)) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await self.testingEventLoop.executeInContext { self.read() } + } + + func waitForPreparedRequest() async throws -> PreparedRequest { + let bind = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let execute = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + + guard case .bind(let bind) = bind, + case .execute(let execute) = execute, + case .sync = sync + else { + fatalError() + } + + return PreparedRequest(bind: bind, execute: execute) + } + + func sendPreparedResponse( + dataRows: [DataRow], + commandTag: String + ) async throws { + try await self.writeInbound(PostgresBackendMessage.bindComplete) + try await self.testingEventLoop.executeInContext { self.read() } + for dataRow in dataRows { + try await self.writeInbound(PostgresBackendMessage.dataRow(dataRow)) + } + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.commandComplete(commandTag)) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await self.testingEventLoop.executeInContext { self.read() } + } } struct UnpreparedRequest { @@ -335,3 +677,13 @@ struct UnpreparedRequest { var bind: PostgresFrontendMessage.Bind var execute: PostgresFrontendMessage.Execute } + +struct PrepareRequest { + var parse: PostgresFrontendMessage.Parse + var describe: PostgresFrontendMessage.Describe +} + +struct PreparedRequest { + var bind: PostgresFrontendMessage.Bind + var execute: PostgresFrontendMessage.Execute +} From ef3a00f9dfd79ad5cd40a0a9fa242e8d3169cf2f Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 18 Aug 2023 13:39:47 +0200 Subject: [PATCH 178/292] Cleanup encoding Startup message (#395) Further cleanup of message encoding: - Move Startup struct into PostgresFrontendMessageEncoder - Move PSQLMessagePayloadEncodable into tests, since it isn't used in PostgresNIO anymore - Only support the parameters that are actually used in encoding startup messages --- .../PostgresNIO/New/Messages/Startup.swift | 52 ------------ .../New/PostgresChannelHandler.swift | 13 +-- .../New/PostgresFrontendMessage.swift | 48 ++++++++++- .../New/PostgresFrontendMessageEncoder.swift | 22 +---- .../PSQLBackendMessageEncoder.swift | 4 + .../New/Messages/StartupTests.swift | 82 ++++++++----------- .../New/PostgresChannelHandlerTests.swift | 11 +++ 7 files changed, 98 insertions(+), 134 deletions(-) delete mode 100644 Sources/PostgresNIO/New/Messages/Startup.swift diff --git a/Sources/PostgresNIO/New/Messages/Startup.swift b/Sources/PostgresNIO/New/Messages/Startup.swift deleted file mode 100644 index 16d23e09..00000000 --- a/Sources/PostgresNIO/New/Messages/Startup.swift +++ /dev/null @@ -1,52 +0,0 @@ -import NIOCore - -extension PostgresFrontendMessage { - struct Startup: Hashable { - static let versionThree: Int32 = 0x00_03_00_00 - - /// Creates a `Startup` with "3.0" as the protocol version. - static func versionThree(parameters: Parameters) -> Startup { - return .init(protocolVersion: Self.versionThree, parameters: parameters) - } - - /// The protocol version number. The most significant 16 bits are the major - /// version number (3 for the protocol described here). The least significant - /// 16 bits are the minor version number (0 for the protocol described here). - var protocolVersion: Int32 - - /// The protocol version number is followed by one or more pairs of parameter - /// name and value strings. A zero byte is required as a terminator after - /// the last name/value pair. `user` is required, others are optional. - struct Parameters: Hashable { - enum Replication { - case `true` - case `false` - case database - } - - /// The database user name to connect as. Required; there is no default. - var user: String - - /// The database to connect to. Defaults to the user name. - var database: String? - - /// Command-line arguments for the backend. (This is deprecated in favor - /// of setting individual run-time parameters.) Spaces within this string are - /// considered to separate arguments, unless escaped with a - /// backslash (\); write \\ to represent a literal backslash. - var options: String? - - /// Used to connect in streaming replication mode, where a small set of - /// replication commands can be issued instead of SQL statements. Value - /// can be true, false, or database, and the default is false. - var replication: Replication - } - var parameters: Parameters - - /// Creates a new `PostgreSQLStartupMessage`. - init(protocolVersion: Int32, parameters: Parameters) { - self.protocolVersion = protocolVersion - self.parameters = parameters - } - } -} diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index bf56d6d1..7b31a776 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -328,7 +328,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { case .wait: break case .sendStartupMessage(let authContext): - self.encoder.startup(authContext.toStartupParameters()) + self.encoder.startup(user: authContext.username, database: authContext.database) context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .sendSSLRequest: self.encoder.ssl() @@ -793,17 +793,6 @@ extension PostgresChannelHandler: PSQLRowsDataSource { } } -extension AuthContext { - func toStartupParameters() -> PostgresFrontendMessage.Startup.Parameters { - PostgresFrontendMessage.Startup.Parameters( - user: self.username, - database: self.database, - options: nil, - replication: .false - ) - } -} - private extension Insecure.MD5.Digest { private static let lowercaseLookup: [UInt8] = [ diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessage.swift b/Sources/PostgresNIO/New/PostgresFrontendMessage.swift index 2a7ec9f1..ef7ce8f8 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessage.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessage.swift @@ -102,6 +102,50 @@ enum PostgresFrontendMessage: Equatable { static let requestCode: Int32 = 80877103 } + struct Startup: Hashable { + static let versionThree: Int32 = 0x00_03_00_00 + + /// Creates a `Startup` with "3.0" as the protocol version. + static func versionThree(parameters: Parameters) -> Startup { + return .init(protocolVersion: Self.versionThree, parameters: parameters) + } + + /// The protocol version number. The most significant 16 bits are the major + /// version number (3 for the protocol described here). The least significant + /// 16 bits are the minor version number (0 for the protocol described here). + var protocolVersion: Int32 + + /// The protocol version number is followed by one or more pairs of parameter + /// name and value strings. A zero byte is required as a terminator after + /// the last name/value pair. `user` is required, others are optional. + struct Parameters: Hashable { + enum Replication { + case `true` + case `false` + case database + } + + /// The database user name to connect as. Required; there is no default. + var user: String + + /// The database to connect to. Defaults to the user name. + var database: String? + + /// Command-line arguments for the backend. (This is deprecated in favor + /// of setting individual run-time parameters.) Spaces within this string are + /// considered to separate arguments, unless escaped with a + /// backslash (\); write \\ to represent a literal backslash. + var options: String? + + /// Used to connect in streaming replication mode, where a small set of + /// replication commands can be issued instead of SQL statements. Value + /// can be true, false, or database, and the default is false. + var replication: Replication + } + + var parameters: Parameters + } + case bind(Bind) case cancel(Cancel) case close(Close) @@ -225,7 +269,3 @@ extension PostgresFrontendMessage { } } } - -protocol PSQLMessagePayloadEncodable { - func encode(into buffer: inout ByteBuffer) -} diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift index 46dbba42..d4747163 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift @@ -13,34 +13,18 @@ struct PostgresFrontendMessageEncoder { self.buffer = buffer } - mutating func startup(_ parameters: PostgresFrontendMessage.Startup.Parameters) { + mutating func startup(user: String, database: String?) { self.clearIfNeeded() self.encodeLengthPrefixed { buffer in buffer.writeInteger(PostgresFrontendMessage.Startup.versionThree) buffer.writeNullTerminatedString("user") - buffer.writeNullTerminatedString(parameters.user) + buffer.writeNullTerminatedString(user) - if let database = parameters.database { + if let database = database { buffer.writeNullTerminatedString("database") buffer.writeNullTerminatedString(database) } - if let options = parameters.options { - buffer.writeNullTerminatedString("options") - buffer.writeNullTerminatedString(options) - } - - switch parameters.replication { - case .database: - buffer.writeNullTerminatedString("replication") - buffer.writeNullTerminatedString("replication") - case .true: - buffer.writeNullTerminatedString("replication") - buffer.writeNullTerminatedString("true") - case .false: - break - } - buffer.writeInteger(UInt8(0)) } } diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift index e51c14f9..9614bf1e 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift @@ -257,3 +257,7 @@ extension RowDescription: PSQLMessagePayloadEncodable { } } } + +protocol PSQLMessagePayloadEncodable { + func encode(into buffer: inout ByteBuffer) +} diff --git a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift index e72f0f34..39e9bb42 100644 --- a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift @@ -4,56 +4,44 @@ import NIOCore class StartupTests: XCTestCase { - func testStartupMessage() { + func testStartupMessageWithDatabase() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) var byteBuffer = ByteBuffer() - - let replicationValues: [PostgresFrontendMessage.Startup.Parameters.Replication] = [ - .`true`, - .`false`, - .database - ] - - for replication in replicationValues { - let parameters = PostgresFrontendMessage.Startup.Parameters( - user: "test", - database: "abc123", - options: "some options", - replication: replication - ) - - encoder.startup(parameters) - byteBuffer = encoder.flushBuffer() - - let byteBufferLength = Int32(byteBuffer.readableBytes) - XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) - XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "options") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "some options") - if replication != .false { - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "replication") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), replication.stringValue) - } - XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) - - XCTAssertEqual(byteBuffer.readableBytes, 0) - } + + let user = "test" + let database = "abc123" + + encoder.startup(user: user, database: database) + byteBuffer = encoder.flushBuffer() + + let byteBufferLength = Int32(byteBuffer.readableBytes) + XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) + XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123") + XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) + + XCTAssertEqual(byteBuffer.readableBytes, 0) } -} -extension PostgresFrontendMessage.Startup.Parameters.Replication { - var stringValue: String { - switch self { - case .true: - return "true" - case .false: - return "false" - case .database: - return "replication" - } + func testStartupMessageWithoutDatabase() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + var byteBuffer = ByteBuffer() + + let user = "test" + + encoder.startup(user: user, database: nil) + byteBuffer = encoder.flushBuffer() + + let byteBufferLength = Int32(byteBuffer.readableBytes) + XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) + XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") + XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) + + XCTAssertEqual(byteBuffer.readableBytes, 0) } } diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index eed5ada7..b047cd72 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -277,3 +277,14 @@ class TestEventHandler: ChannelInboundHandler { self.events.append(psqlEvent) } } + +extension AuthContext { + func toStartupParameters() -> PostgresFrontendMessage.Startup.Parameters { + PostgresFrontendMessage.Startup.Parameters( + user: self.username, + database: self.database, + options: nil, + replication: .false + ) + } +} From c1de89a187eca87eafb1ca398645845e4ed8af23 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 18 Aug 2023 16:37:59 +0200 Subject: [PATCH 179/292] Make sure correct error is thrown, if server closes connection (#397) --- .../ConnectionStateMachine.swift | 28 ++++++++++--------- .../New/PostgresConnectionTests.swift | 28 +++++++++++++++++++ 2 files changed, 43 insertions(+), 13 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index bbfa0faa..b7ecc461 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -203,7 +203,7 @@ struct ConnectionStateMachine { preconditionFailure("How can a connection be closed, if it was never connected.") case .closed: - preconditionFailure("How can a connection be closed, if it is already closed.") + return .wait case .authenticated, .sslRequestSent, @@ -214,8 +214,8 @@ struct ConnectionStateMachine { .readyForQuery, .extendedQuery, .closeCommand: - return self.errorHappened(.uncleanShutdown) - + return self.errorHappened(.serverClosedConnection(underlying: nil)) + case .closing(let error): self.state = .closed(clientInitiated: true, error: error) self.quiescingState = .notQuiescing @@ -910,7 +910,7 @@ struct ConnectionStateMachine { // the error state and will try to close the connection. However the server might have // send further follow up messages. In those cases we will run into this method again // and again. We should just ignore those events. - return .wait + return .closeConnection(closePromise) case .modifying: preconditionFailure("Invalid state: \(self.state)") @@ -1034,16 +1034,16 @@ extension ConnectionStateMachine { case .clientClosesConnection, .clientClosedConnection: preconditionFailure("Pure client error, that is thrown directly in PostgresConnection") case .serverClosedConnection: - preconditionFailure("Pure client error, that is thrown directly and should never ") + return true } } mutating func setErrorAndCreateCleanupContextIfNeeded(_ error: PSQLError) -> ConnectionAction.CleanUpContext? { - guard self.shouldCloseConnection(reason: error) else { - return nil + if self.shouldCloseConnection(reason: error) { + return self.setErrorAndCreateCleanupContext(error) } - return self.setErrorAndCreateCleanupContext(error) + return nil } mutating func setErrorAndCreateCleanupContext(_ error: PSQLError, closePromise: EventLoopPromise? = nil) -> ConnectionAction.CleanUpContext { @@ -1060,13 +1060,15 @@ extension ConnectionStateMachine { forwardedPromise = closePromise } - self.state = .closing(error) - - var action = ConnectionAction.CleanUpContext.Action.close - if case .uncleanShutdown = error.code.base { + let action: ConnectionAction.CleanUpContext.Action + if case .serverClosedConnection = error.code.base { + self.state = .closed(clientInitiated: false, error: error) action = .fireChannelInactive + } else { + self.state = .closing(error) + action = .close } - + return .init(action: action, tasks: tasks, error: error, closePromise: forwardedPromise) } } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 9c4dc5cb..59917c40 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -275,6 +275,34 @@ class PostgresConnectionTests: XCTestCase { } } + func testIfServerJustClosesTheErrorReflectsThat() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + async let response = try await connection.query("SELECT 1;", logger: self.logger) + + let listenMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") + + try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelInactive() } + try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelUnregistered() } + + do { + _ = try await response + XCTFail("Expected to throw") + } catch { + XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection) + } + + // retry on same connection + + do { + _ = try await connection.query("SELECT 1;", logger: self.logger) + XCTFail("Expected to throw") + } catch { + XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection) + } + } + struct TestPrepareStatement: PostgresPreparedStatement { static var sql = "SELECT datname FROM pg_stat_activity WHERE state = $1" typealias Row = String From 12584c6666bd0b197e8063ef2415a7c9281152fb Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Fri, 18 Aug 2023 13:54:32 -0500 Subject: [PATCH 180/292] Fix a few inaccurate or confusing precondition failure messages (#398) --- .../ConnectionStateMachine.swift | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index b7ecc461..22c4087e 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -858,8 +858,9 @@ struct ConnectionStateMachine { // substate machine. return .closeConnectionAndCleanup(cleanupContext) } - - switch queryStateMachine.errorHappened(error) { + + let action = queryStateMachine.errorHappened(error) + switch action { case .sendParseDescribeBindExecuteSync, .sendParseDescribeSync, .sendBindExecuteSync, @@ -869,7 +870,7 @@ struct ConnectionStateMachine { .forwardStreamComplete, .wait, .read: - preconditionFailure("Invalid state: \(self.state)") + preconditionFailure("Invalid query state machine action in state: \(self.state), action: \(action)") case .evaluateErrorAtConnectionLevel: return .closeConnectionAndCleanup(cleanupContext) @@ -894,12 +895,13 @@ struct ConnectionStateMachine { return .closeConnectionAndCleanup(cleanupContext) } - switch closeStateMachine.errorHappened(error) { + let action = closeStateMachine.errorHappened(error) + switch action { case .sendCloseSync, .succeedClose, .read, .wait: - preconditionFailure("Invalid state: \(self.state)") + preconditionFailure("Invalid close state machine action in state: \(self.state), action: \(action)") case .failClose(let closeCommandContext, with: let error): return .failClose(closeCommandContext, with: error, cleanupContext: cleanupContext) } @@ -1032,7 +1034,7 @@ extension ConnectionStateMachine { return false case .clientClosesConnection, .clientClosedConnection: - preconditionFailure("Pure client error, that is thrown directly in PostgresConnection") + preconditionFailure("A pure client error was thrown directly in PostgresConnection, this shouldn't happen") case .serverClosedConnection: return true } From 9a02d740a0fdb6fa52818c91d27875deb05add24 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 19 Aug 2023 11:10:55 +0200 Subject: [PATCH 181/292] Move PostgresFrontendMessage to tests (#399) --- .../New/Extensions/ByteBuffer+PSQL.swift | 8 -- .../New/PostgresFrontendMessageEncoder.swift | 95 +++++++++++++------ .../New/Extensions/ByteBuffer+Utils.swift | 5 +- .../Extensions}/PostgresFrontendMessage.swift | 1 + 4 files changed, 71 insertions(+), 38 deletions(-) rename {Sources/PostgresNIO/New => Tests/PostgresNIOTests/New/Extensions}/PostgresFrontendMessage.swift (99%) diff --git a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift index 6d632b6f..838e624d 100644 --- a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift +++ b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift @@ -2,14 +2,6 @@ import NIOCore internal extension ByteBuffer { - mutating func psqlWriteBackendMessageID(_ messageID: PostgresBackendMessage.ID) { - self.writeInteger(messageID.rawValue) - } - - mutating func psqlWriteFrontendMessageID(_ messageID: PostgresFrontendMessage.ID) { - self.writeInteger(messageID.rawValue) - } - @usableFromInline mutating func psqlReadFloat() -> Float? { return self.readInteger(as: UInt32.self).map { Float(bitPattern: $0) } diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift index d4747163..e98ab1f1 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift @@ -1,6 +1,18 @@ import NIOCore struct PostgresFrontendMessageEncoder { + + /// The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits, + /// and 5679 in the least significant 16 bits. + static let sslRequestCode: Int32 = 80877103 + + /// The cancel request code. The value is chosen to contain 1234 in the most significant 16 bits, + /// and 5678 in the least significant 16 bits. (To avoid confusion, this code must not be the same + /// as any protocol version number.) + static let cancelRequestCode: Int32 = 80877102 + + static let startupVersionThree: Int32 = 0x00_03_00_00 + private enum State { case flushed case writable @@ -15,8 +27,8 @@ struct PostgresFrontendMessageEncoder { mutating func startup(user: String, database: String?) { self.clearIfNeeded() - self.encodeLengthPrefixed { buffer in - buffer.writeInteger(PostgresFrontendMessage.Startup.versionThree) + self.buffer.psqlLengthPrefixed { buffer in + buffer.writeInteger(Self.startupVersionThree) buffer.writeNullTerminatedString("user") buffer.writeNullTerminatedString(user) @@ -31,8 +43,7 @@ struct PostgresFrontendMessageEncoder { mutating func bind(portalName: String, preparedStatementName: String, bind: PostgresBindings) { self.clearIfNeeded() - self.buffer.psqlWriteFrontendMessageID(.bind) - self.encodeLengthPrefixed { buffer in + self.buffer.psqlLengthPrefixed(id: .bind) { buffer in buffer.writeNullTerminatedString(portalName) buffer.writeNullTerminatedString(preparedStatementName) @@ -65,45 +76,45 @@ struct PostgresFrontendMessageEncoder { mutating func cancel(processID: Int32, secretKey: Int32) { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(UInt32(16), PostgresFrontendMessage.Cancel.requestCode, processID, secretKey) + self.buffer.writeMultipleIntegers(UInt32(16), Self.cancelRequestCode, processID, secretKey) } mutating func closePreparedStatement(_ preparedStatement: String) { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.close.rawValue, UInt32(6 + preparedStatement.utf8.count), UInt8(ascii: "S")) + self.buffer.psqlWriteMultipleIntegers(id: .close, length: UInt32(2 + preparedStatement.utf8.count), UInt8(ascii: "S")) self.buffer.writeNullTerminatedString(preparedStatement) } mutating func closePortal(_ portal: String) { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.close.rawValue, UInt32(6 + portal.utf8.count), UInt8(ascii: "P")) + self.buffer.psqlWriteMultipleIntegers(id: .close, length: UInt32(2 + portal.utf8.count), UInt8(ascii: "P")) self.buffer.writeNullTerminatedString(portal) } mutating func describePreparedStatement(_ preparedStatement: String) { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.describe.rawValue, UInt32(6 + preparedStatement.utf8.count), UInt8(ascii: "S")) + self.buffer.psqlWriteMultipleIntegers(id: .describe, length: UInt32(2 + preparedStatement.utf8.count), UInt8(ascii: "S")) self.buffer.writeNullTerminatedString(preparedStatement) } mutating func describePortal(_ portal: String) { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.describe.rawValue, UInt32(6 + portal.utf8.count), UInt8(ascii: "P")) + self.buffer.psqlWriteMultipleIntegers(id: .describe, length: UInt32(2 + portal.utf8.count), UInt8(ascii: "P")) self.buffer.writeNullTerminatedString(portal) } mutating func execute(portalName: String, maxNumberOfRows: Int32 = 0) { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.execute.rawValue, UInt32(9 + portalName.utf8.count)) + self.buffer.psqlWriteMultipleIntegers(id: .execute, length: UInt32(5 + portalName.utf8.count)) self.buffer.writeNullTerminatedString(portalName) self.buffer.writeInteger(maxNumberOfRows) } mutating func parse(preparedStatementName: String, query: String, parameters: Parameters) where Parameters.Element == PostgresDataType { self.clearIfNeeded() - self.buffer.writeMultipleIntegers( - PostgresFrontendMessage.ID.parse.rawValue, - UInt32(4 + preparedStatementName.utf8.count + 1 + query.utf8.count + 1 + 2 + MemoryLayout.size * parameters.count) + self.buffer.psqlWriteMultipleIntegers( + id: .parse, + length: UInt32(preparedStatementName.utf8.count + 1 + query.utf8.count + 1 + 2 + MemoryLayout.size * parameters.count) ) self.buffer.writeNullTerminatedString(preparedStatementName) self.buffer.writeNullTerminatedString(query) @@ -116,28 +127,25 @@ struct PostgresFrontendMessageEncoder { mutating func password(_ bytes: Bytes) where Bytes.Element == UInt8 { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.password.rawValue, UInt32(5 + bytes.count)) + self.buffer.psqlWriteMultipleIntegers(id: .password, length: UInt32(bytes.count) + 1) self.buffer.writeBytes(bytes) self.buffer.writeInteger(UInt8(0)) } mutating func flush() { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.flush.rawValue, UInt32(4)) + self.buffer.psqlWriteMultipleIntegers(id: .flush, length: 0) } mutating func saslResponse(_ bytes: Bytes) where Bytes.Element == UInt8 { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.saslResponse.rawValue, UInt32(4 + bytes.count)) + self.buffer.psqlWriteMultipleIntegers(id: .password, length: UInt32(bytes.count)) self.buffer.writeBytes(bytes) } mutating func saslInitialResponse(mechanism: String, bytes: Bytes) where Bytes.Element == UInt8 { self.clearIfNeeded() - self.buffer.writeMultipleIntegers( - PostgresFrontendMessage.ID.saslInitialResponse.rawValue, - UInt32(4 + mechanism.utf8.count + 1 + 4 + bytes.count) - ) + self.buffer.psqlWriteMultipleIntegers(id: .password, length: UInt32(mechanism.utf8.count + 1 + 4 + bytes.count)) self.buffer.writeNullTerminatedString(mechanism) if bytes.count > 0 { self.buffer.writeInteger(Int32(bytes.count)) @@ -149,17 +157,17 @@ struct PostgresFrontendMessageEncoder { mutating func ssl() { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(UInt32(8), PostgresFrontendMessage.SSLRequest.requestCode) + self.buffer.writeMultipleIntegers(UInt32(8), Self.sslRequestCode) } mutating func sync() { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.sync.rawValue, UInt32(4)) + self.buffer.psqlWriteMultipleIntegers(id: .sync, length: 0) } mutating func terminate() { self.clearIfNeeded() - self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.terminate.rawValue, UInt32(4)) + self.buffer.psqlWriteMultipleIntegers(id: .terminate, length: 0) } mutating func flushBuffer() -> ByteBuffer { @@ -177,13 +185,42 @@ struct PostgresFrontendMessageEncoder { break } } +} - private mutating func encodeLengthPrefixed(_ encode: (inout ByteBuffer) -> ()) { - let startIndex = self.buffer.writerIndex - self.buffer.writeInteger(UInt32(0)) // placeholder for length - encode(&self.buffer) - let length = UInt32(self.buffer.writerIndex - startIndex) - self.buffer.setInteger(length, at: startIndex) +private enum FrontendMessageID: UInt8, Hashable, Sendable { + case bind = 66 // B + case close = 67 // C + case describe = 68 // D + case execute = 69 // E + case flush = 72 // H + case parse = 80 // P + case password = 112 // p - also both sasl values + case sync = 83 // S + case terminate = 88 // X +} + +extension ByteBuffer { + mutating fileprivate func psqlWriteMultipleIntegers(id: FrontendMessageID, length: UInt32) { + self.writeMultipleIntegers(id.rawValue, 4 + length) + } + + mutating fileprivate func psqlWriteMultipleIntegers(id: FrontendMessageID, length: UInt32, _ t1: T1) { + self.writeMultipleIntegers(id.rawValue, 4 + length, t1) } + mutating fileprivate func psqlLengthPrefixed(id: FrontendMessageID, _ encode: (inout ByteBuffer) -> ()) { + let lengthIndex = self.writerIndex + 1 + self.psqlWriteMultipleIntegers(id: id, length: 0) + encode(&self) + let length = UInt32(self.writerIndex - lengthIndex) + self.setInteger(length, at: lengthIndex) + } + + mutating fileprivate func psqlLengthPrefixed(_ encode: (inout ByteBuffer) -> ()) { + let lengthIndex = self.writerIndex + self.writeInteger(UInt32(0)) // placeholder + encode(&self) + let length = UInt32(self.writerIndex - lengthIndex) + self.setInteger(length, at: lengthIndex) + } } diff --git a/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift b/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift index 71994596..7d073873 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift @@ -2,7 +2,10 @@ import NIOCore @testable import PostgresNIO extension ByteBuffer { - + mutating func psqlWriteBackendMessageID(_ messageID: PostgresBackendMessage.ID) { + self.writeInteger(messageID.rawValue) + } + static func backendMessage(id: PostgresBackendMessage.ID, _ payload: (inout ByteBuffer) throws -> ()) rethrows -> ByteBuffer { var byteBuffer = ByteBuffer() try byteBuffer.writeBackendMessage(id: id, payload) diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessage.swift b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift similarity index 99% rename from Sources/PostgresNIO/New/PostgresFrontendMessage.swift rename to Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift index ef7ce8f8..010667dc 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessage.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift @@ -1,4 +1,5 @@ import NIOCore +import PostgresNIO /// A wire message that is created by a Postgres client to be consumed by Postgres server. /// From 8f8557bfe6a3ca379da2cf84059acbdba1c3958f Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sun, 20 Aug 2023 17:46:21 +0200 Subject: [PATCH 182/292] Remove PSQLError.Code.clientClosesConnection (#400) --- .../ConnectionStateMachine.swift | 6 +++--- Sources/PostgresNIO/New/PSQLError.swift | 14 ++------------ .../PostgresNIO/New/PostgresChannelHandler.swift | 6 ++++-- Sources/PostgresNIO/Postgres+PSQLCompat.swift | 3 +-- .../New/PostgresChannelHandlerTests.swift | 3 +-- 5 files changed, 11 insertions(+), 21 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 22c4087e..eca251ff 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -550,7 +550,7 @@ struct ConnectionStateMachine { // check if we are quiescing. if so fail task immidiatly switch self.quiescingState { case .quiescing: - psqlErrror = PSQLError.clientClosesConnection(underlying: nil) + psqlErrror = PSQLError.clientClosedConnection(underlying: nil) case .notQuiescing: switch self.state { @@ -570,7 +570,7 @@ struct ConnectionStateMachine { return self.executeTask(task) case .closing(let error): - psqlErrror = PSQLError.clientClosesConnection(underlying: error) + psqlErrror = PSQLError.clientClosedConnection(underlying: error) case .closed(clientInitiated: true, error: let error): psqlErrror = PSQLError.clientClosedConnection(underlying: error) @@ -1033,7 +1033,7 @@ extension ConnectionStateMachine { } return false - case .clientClosesConnection, .clientClosedConnection: + case .clientClosedConnection: preconditionFailure("A pure client error was thrown directly in PostgresConnection, this shouldn't happen") case .serverClosedConnection: return true diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index 1fec59b1..7060a690 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -18,7 +18,6 @@ public struct PSQLError: Error { case queryCancelled case tooManyParameters - case clientClosesConnection case clientClosedConnection case serverClosedConnection case connectionError @@ -46,7 +45,6 @@ public struct PSQLError: Error { public static let invalidCommandTag = Self(.invalidCommandTag) public static let queryCancelled = Self(.queryCancelled) public static let tooManyParameters = Self(.tooManyParameters) - public static let clientClosesConnection = Self(.clientClosesConnection) public static let clientClosedConnection = Self(.clientClosedConnection) public static let serverClosedConnection = Self(.serverClosedConnection) public static let connectionError = Self(.connectionError) @@ -54,8 +52,8 @@ public struct PSQLError: Error { public static let listenFailed = Self.init(.listenFailed) public static let unlistenFailed = Self.init(.unlistenFailed) - @available(*, deprecated, renamed: "clientClosesConnection") - public static let connectionQuiescing = Self.clientClosesConnection + @available(*, deprecated, renamed: "clientClosedConnection") + public static let connectionQuiescing = Self.clientClosedConnection @available(*, deprecated, message: "Use the more specific `serverClosedConnection` or `clientClosedConnection` instead") public static let connectionClosed = Self.serverClosedConnection @@ -86,8 +84,6 @@ public struct PSQLError: Error { return "queryCancelled" case .tooManyParameters: return "tooManyParameters" - case .clientClosesConnection: - return "clientClosesConnection" case .clientClosedConnection: return "clientClosedConnection" case .serverClosedConnection: @@ -387,12 +383,6 @@ public struct PSQLError: Error { return new } - static func clientClosesConnection(underlying: Error?) -> PSQLError { - var error = PSQLError(code: .clientClosesConnection) - error.underlying = underlying - return error - } - static func clientClosedConnection(underlying: Error?) -> PSQLError { var error = PSQLError(code: .clientClosedConnection) error.underlying = underlying diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 7b31a776..6d9d08b3 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -576,8 +576,10 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } // 3. fire an error - context.fireErrorCaught(cleanup.error) - + if cleanup.error.code != .clientClosedConnection { + context.fireErrorCaught(cleanup.error) + } + // 4. close the connection or fire channel inactive switch cleanup.action { case .close: diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index 1989e5bc..c4f30624 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -37,8 +37,7 @@ extension PSQLError { return self.underlying ?? self case .tooManyParameters, .invalidCommandTag: return self - case .clientClosesConnection, - .clientClosedConnection, + case .clientClosedConnection, .serverClosedConnection: return PostgresError.connectionClosed case .connectionError: diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index b047cd72..b81d0899 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -25,8 +25,7 @@ class PostgresChannelHandlerTests: XCTestCase { handler ], loop: self.eventLoop) defer { - do { try embedded.finish() } - catch { print("\(String(reflecting: error))") } + XCTAssertNoThrow({ try embedded.finish() }) } var maybeMessage: PostgresFrontendMessage? From 689e4aabd783df4d8fb0eedee0787014a141f9e8 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 25 Aug 2023 17:12:10 +0200 Subject: [PATCH 183/292] Use variadic generics to decode rows in Swift 5.9 (#341) --- .../New/PostgresRow-multi-decode.swift | 2 + .../PostgresRowSequence-multi-decode.swift | 2 +- .../PostgresNIO/New/VariadicGenerics.swift | 174 ++++++++++++++++++ Tests/IntegrationTests/AsyncTests.swift | 7 +- Tests/IntegrationTests/PostgresNIOTests.swift | 8 +- .../New/PostgresRowSequenceTests.swift | 12 +- 6 files changed, 191 insertions(+), 14 deletions(-) create mode 100644 Sources/PostgresNIO/New/VariadicGenerics.swift diff --git a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift index cb62c325..71aa04dc 100644 --- a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift +++ b/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift @@ -1,5 +1,6 @@ /// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrow-multi-decode.sh +#if compiler(<5.9) extension PostgresRow { @inlinable @_alwaysEmitIntoClient @@ -1171,3 +1172,4 @@ extension PostgresRow { try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: .default, file: file, line: line) } } +#endif diff --git a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift index 53d9a7ea..f45357d8 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift @@ -1,6 +1,6 @@ /// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrowsequence-multi-decode.sh -#if canImport(_Concurrency) +#if compiler(<5.9) extension AsyncSequence where Element == PostgresRow { @inlinable @_alwaysEmitIntoClient diff --git a/Sources/PostgresNIO/New/VariadicGenerics.swift b/Sources/PostgresNIO/New/VariadicGenerics.swift new file mode 100644 index 00000000..312d36dc --- /dev/null +++ b/Sources/PostgresNIO/New/VariadicGenerics.swift @@ -0,0 +1,174 @@ +#if compiler(>=5.9) +extension PostgresRow { + // --- snip TODO: Remove once bug is fixed, that disallows tuples of one + @inlinable + public func decode( + _: Column.Type, + file: String = #fileID, + line: Int = #line + ) throws -> (Column) { + try self.decode(Column.self, context: .default, file: file, line: line) + } + + @inlinable + public func decode( + _: Column.Type, + context: PostgresDecodingContext, + file: String = #fileID, + line: Int = #line + ) throws -> (Column) { + precondition(self.columns.count >= 1) + let columnIndex = 0 + var cellIterator = self.data.makeIterator() + var cellData = cellIterator.next().unsafelyUnwrapped + var columnIterator = self.columns.makeIterator() + let column = columnIterator.next().unsafelyUnwrapped + let swiftTargetType: Any.Type = Column.self + + do { + let r0 = try Column._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + + return (r0) + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: swiftTargetType, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + // --- snap TODO: Remove once bug is fixed, that disallows tuples of one + + @inlinable + public func decode( + _ columnType: (repeat each Column).Type, + context: PostgresDecodingContext, + file: String = #fileID, + line: Int = #line + ) throws -> (repeat each Column) { + let packCount = ComputeParameterPackLength.count(ofPack: repeat (each Column).self) + precondition(self.columns.count >= packCount) + + var columnIndex = 0 + var cellIterator = self.data.makeIterator() + var columnIterator = self.columns.makeIterator() + + return ( + repeat try Self.decodeNextColumn( + (each Column).self, + cellIterator: &cellIterator, + columnIterator: &columnIterator, + columnIndex: &columnIndex, + context: context, + file: file, + line: line + ) + ) + } + + @inlinable + static func decodeNextColumn( + _ columnType: Column.Type, + cellIterator: inout IndexingIterator, + columnIterator: inout IndexingIterator<[RowDescription.Column]>, + columnIndex: inout Int, + context: PostgresDecodingContext, + file: String, + line: Int + ) throws -> Column { + defer { columnIndex += 1 } + + let column = columnIterator.next().unsafelyUnwrapped + var cellData = cellIterator.next().unsafelyUnwrapped + do { + return try Column._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) + } catch let code as PostgresDecodingError.Code { + throw PostgresDecodingError( + code: code, + columnName: column.name, + columnIndex: columnIndex, + targetType: Column.self, + postgresType: column.dataType, + postgresFormat: column.format, + postgresData: cellData, + file: file, + line: line + ) + } + } + + @inlinable + public func decode( + _ columnType: (repeat each Column).Type, + file: String = #fileID, + line: Int = #line + ) throws -> (repeat each Column) { + try self.decode(columnType, context: .default, file: file, line: line) + } +} + +extension AsyncSequence where Element == PostgresRow { + // --- snip TODO: Remove once bug is fixed, that disallows tuples of one + @inlinable + public func decode( + _: Column.Type, + context: PostgresDecodingContext, + file: String = #fileID, + line: Int = #line + ) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode(Column.self, context: context, file: file, line: line) + } + } + + @inlinable + public func decode( + _: Column.Type, + file: String = #fileID, + line: Int = #line + ) -> AsyncThrowingMapSequence { + self.decode(Column.self, context: .default, file: file, line: line) + } + // --- snap TODO: Remove once bug is fixed, that disallows tuples of one + + public func decode( + _ columnType: (repeat each Column).Type, + context: PostgresDecodingContext, + file: String = #fileID, + line: Int = #line + ) -> AsyncThrowingMapSequence { + self.map { row in + try row.decode(columnType, context: context, file: file, line: line) + } + } + + public func decode( + _ columnType: (repeat each Column).Type, + file: String = #fileID, + line: Int = #line + ) -> AsyncThrowingMapSequence { + self.decode(columnType, context: .default, file: file, line: line) + } +} + +@usableFromInline +enum ComputeParameterPackLength { + @usableFromInline + enum BoolConverter { + @usableFromInline + typealias Bool = Swift.Bool + } + + @inlinable + static func count(ofPack t: repeat each T) -> Int { + MemoryLayout<(repeat BoolConverter.Bool)>.size / MemoryLayout.stride + } +} +#endif // compiler(>=5.9) + diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index bf945a67..5c77ba29 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -8,7 +8,6 @@ import NIOPosix import NIOCore final class AsyncPostgresConnectionTests: XCTestCase { - func test1kRoundTrips() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } @@ -37,7 +36,8 @@ final class AsyncPostgresConnectionTests: XCTestCase { try await withTestConnection(on: eventLoop) { connection in let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest) var counter = 0 - for try await element in rows.decode(Int.self, context: .default) { + for try await row in rows { + let element = try row.decode(Int.self) XCTAssertEqual(element, counter + 1) counter += 1 } @@ -259,7 +259,8 @@ final class AsyncPostgresConnectionTests: XCTestCase { try await withTestConnection(on: eventLoop) { connection in let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest) var counter = 1 - for try await element in rows.decode(Int.self, context: .default) { + for try await row in rows { + let element = try row.decode(Int.self, context: .default) XCTAssertEqual(element, counter) counter += 1 } diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index 19c4e167..ea4d8d05 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -1246,10 +1246,10 @@ final class PostgresNIOTests: XCTestCase { return EventLoopFuture.whenAllSucceed([a, b, c], on: self.eventLoop) }).wait()) XCTAssertEqual(queries?.count, 3) - var resutIterator = queries?.makeIterator() - XCTAssertEqual(try resutIterator?.next()?.first?.decode(String.self, context: .default), "a") - XCTAssertEqual(try resutIterator?.next()?.first?.decode(String.self, context: .default), "b") - XCTAssertEqual(try resutIterator?.next()?.first?.decode(String.self, context: .default), "c") + var resultIterator = queries?.makeIterator() + XCTAssertEqual(try resultIterator?.next()?.first?.decode(String.self, context: .default), "a") + XCTAssertEqual(try resultIterator?.next()?.first?.decode(String.self, context: .default), "b") + XCTAssertEqual(try resultIterator?.next()?.first?.decode(String.self, context: .default), "c") } // https://github.com/vapor/postgres-nio/issues/122 diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index 872c098d..816daf04 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -59,7 +59,7 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 0 for try await row in rowSequence { - XCTAssertEqual(try row.decode(Int.self, context: .default), counter) + XCTAssertEqual(try row.decode(Int.self), counter) counter += 1 if counter == 64 { @@ -135,7 +135,7 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 0 for try await row in rowSequence { - XCTAssertEqual(try row.decode(Int.self, context: .default), counter) + XCTAssertEqual(try row.decode(Int.self), counter) counter += 1 } @@ -163,7 +163,7 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 0 for try await row in rowSequence { - XCTAssertEqual(try row.decode(Int.self, context: .default), counter) + XCTAssertEqual(try row.decode(Int.self), counter) counter += 1 } @@ -220,7 +220,7 @@ final class PostgresRowSequenceTests: XCTestCase { } let row1 = try await rowIterator.next() - XCTAssertEqual(try row1?.decode(Int.self, context: .default), 0) + XCTAssertEqual(try row1?.decode(Int.self), 0) DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { stream.receive(completion: .success("SELECT 1")) @@ -252,7 +252,7 @@ final class PostgresRowSequenceTests: XCTestCase { } let row1 = try await rowIterator.next() - XCTAssertEqual(try row1?.decode(Int.self, context: .default), 0) + XCTAssertEqual(try row1?.decode(Int.self), 0) DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { stream.receive(completion: .failure(PSQLError.serverClosedConnection(underlying: nil))) @@ -415,7 +415,7 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 1 for _ in 0..<(2 * messagePerChunk - 1) { let row = try await rowIterator.next() - XCTAssertEqual(try row?.decode(Int.self, context: .default), counter) + XCTAssertEqual(try row?.decode(Int.self), counter) counter += 1 } From 0d9f13be024047397c0f1bf72edf7ffd36cac67a Mon Sep 17 00:00:00 2001 From: Marius Seufzer <44228394+marius-se@users.noreply.github.com> Date: Sun, 27 Aug 2023 23:42:19 +1200 Subject: [PATCH 184/292] Add `PostgresDynamicTypeThrowingEncodable` and `PostgresDynamicTypeEncodable` (#365) --- .../New/Data/Array+PostgresCodable.swift | 7 ++ .../New/Data/Range+PostgresCodable.swift | 22 ++++++- Sources/PostgresNIO/New/PostgresCodable.swift | 65 +++++++++++++++---- Sources/PostgresNIO/New/PostgresQuery.swift | 22 +++---- .../New/PostgresQueryTests.swift | 37 +++++++++++ 5 files changed, 128 insertions(+), 25 deletions(-) diff --git a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift index fb2b62e3..d605a6c1 100644 --- a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift @@ -136,6 +136,10 @@ extension Array: PostgresEncodable where Element: PostgresArrayEncodable { } } +// explicitly conforming to PostgresThrowingDynamicTypeEncodable because of: +// https://github.com/apple/swift/issues/54132 +extension Array: PostgresThrowingDynamicTypeEncodable where Element: PostgresArrayEncodable {} + extension Array: PostgresNonThrowingEncodable where Element: PostgresArrayEncodable & PostgresNonThrowingEncodable { public static var psqlType: PostgresDataType { Element.psqlArrayType @@ -173,6 +177,9 @@ extension Array: PostgresNonThrowingEncodable where Element: PostgresArrayEncoda } } +// explicitly conforming to PostgresDynamicTypeEncodable because of: +// https://github.com/apple/swift/issues/54132 +extension Array: PostgresDynamicTypeEncodable where Element: PostgresArrayEncodable & PostgresNonThrowingEncodable {} extension Array: PostgresDecodable where Element: PostgresArrayDecodable, Element == Element._DecodableType { public init( diff --git a/Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift index e5a3e60e..6279cf4b 100644 --- a/Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Range+PostgresCodable.swift @@ -191,6 +191,11 @@ extension PostgresRange: PostgresEncodable & PostgresNonThrowingEncodable where } } +// explicitly conforming to PostgresDynamicTypeEncodable and PostgresThrowingDynamicTypeEncodable because of: +// https://github.com/apple/swift/issues/54132 +extension PostgresRange: PostgresThrowingDynamicTypeEncodable & PostgresDynamicTypeEncodable + where Bound: PostgresRangeEncodable {} + extension PostgresRange where Bound: Comparable { @inlinable init(range: Range) { @@ -227,6 +232,11 @@ extension Range: PostgresEncodable where Bound: PostgresRangeEncodable { extension Range: PostgresNonThrowingEncodable where Bound: PostgresRangeEncodable {} +// explicitly conforming to PostgresDynamicTypeEncodable and PostgresThrowingDynamicTypeEncodable because of: +// https://github.com/apple/swift/issues/54132 +extension Range: PostgresDynamicTypeEncodable & PostgresThrowingDynamicTypeEncodable + where Bound: PostgresRangeEncodable {} + extension Range: PostgresDecodable where Bound: PostgresRangeDecodable { @inlinable public init( @@ -249,7 +259,7 @@ extension Range: PostgresDecodable where Bound: PostgresRangeDecodable { else { throw PostgresDecodingError.Code.failure } - + self = lowerBound..( @@ -301,7 +319,7 @@ extension ClosedRange: PostgresDecodable where Bound: PostgresRangeDecodable { if lowerBound > upperBound { throw PostgresDecodingError.Code.failure } - + self = lowerBound...upperBound } } diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift index 36937de4..53dbd708 100644 --- a/Sources/PostgresNIO/New/PostgresCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -2,29 +2,62 @@ import NIOCore import class Foundation.JSONEncoder import class Foundation.JSONDecoder +/// A type that can encode itself to a Postgres wire binary representation. +/// Dynamic types are types that don't have a well-known Postgres type OID at compile time. +/// For example, custom types created at runtime, such as enums, or extension types whose OID is not stable between +/// databases. +public protocol PostgresThrowingDynamicTypeEncodable { + /// The data type encoded into the `byteBuffer` in ``encode(into:context:)`` + var psqlType: PostgresDataType { get } + + /// The Postgres encoding format used to encode the value into `byteBuffer` in ``encode(into:context:)``. + var psqlFormat: PostgresFormat { get } + + /// Encode the entity into ``byteBuffer`` in the format specified by ``psqlFormat``, + /// using the provided ``context`` as needed, without setting the byte count. + /// + /// This method is called by ``PostgresBindings``. + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) throws +} + +/// A type that can encode itself to a Postgres wire binary representation. +/// Dynamic types are types that don't have a well-known Postgres type OID at compile time. +/// For example, custom types created at runtime, such as enums, or extension types whose OID is not stable between +/// databases. +/// +/// This is the non-throwing alternative to ``PostgresThrowingDynamicTypeEncodable``. It allows users +/// to create ``PostgresQuery``s via `ExpressibleByStringInterpolation` without having to spell `try`. +public protocol PostgresDynamicTypeEncodable: PostgresThrowingDynamicTypeEncodable { + /// Encode the entity into ``byteBuffer`` in the format specified by ``psqlFormat``, + /// using the provided ``context`` as needed, without setting the byte count. + /// + /// This method is called by ``PostgresBindings``. + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresEncodingContext + ) +} + /// A type that can encode itself to a postgres wire binary representation. -public protocol PostgresEncodable { +public protocol PostgresEncodable: PostgresThrowingDynamicTypeEncodable { // TODO: Rename to `PostgresThrowingEncodable` with next major release - /// identifies the data type that we will encode into `byteBuffer` in `encode` + /// The data type encoded into the `byteBuffer` in ``encode(into:context:)``. static var psqlType: PostgresDataType { get } - /// identifies the postgres format that is used to encode the value into `byteBuffer` in `encode` + /// The Postgres encoding format used to encode the value into `byteBuffer` in ``encode(into:context:)``. static var psqlFormat: PostgresFormat { get } - - /// Encode the entity into the `byteBuffer` in Postgres binary format, without setting - /// the byte count. This method is called from the ``PostgresBindings``. - func encode(into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext) throws } /// A type that can encode itself to a postgres wire binary representation. It enforces that the /// ``PostgresEncodable/encode(into:context:)-1jkcp`` does not throw. This allows users -/// to create ``PostgresQuery``s using the `ExpressibleByStringInterpolation` without +/// to create ``PostgresQuery``s via `ExpressibleByStringInterpolation` without /// having to spell `try`. -public protocol PostgresNonThrowingEncodable: PostgresEncodable { +public protocol PostgresNonThrowingEncodable: PostgresEncodable, PostgresDynamicTypeEncodable { // TODO: Rename to `PostgresEncodable` with next major release - - func encode(into byteBuffer: inout ByteBuffer, context: PostgresEncodingContext) } /// A type that can decode itself from a postgres wire binary representation. @@ -84,6 +117,14 @@ extension PostgresDecodable { public typealias PostgresCodable = PostgresEncodable & PostgresDecodable extension PostgresEncodable { + @inlinable + public var psqlType: PostgresDataType { Self.psqlType } + + @inlinable + public var psqlFormat: PostgresFormat { Self.psqlFormat } +} + +extension PostgresThrowingDynamicTypeEncodable { @inlinable func encodeRaw( into buffer: inout ByteBuffer, @@ -103,7 +144,7 @@ extension PostgresEncodable { } } -extension PostgresNonThrowingEncodable { +extension PostgresDynamicTypeEncodable { @inlinable func encodeRaw( into buffer: inout ByteBuffer, diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 4ca1e454..1cfcf2dc 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -44,13 +44,13 @@ extension PostgresQuery { } @inlinable - public mutating func appendInterpolation(_ value: Value) throws { + public mutating func appendInterpolation(_ value: Value) throws { try self.binds.append(value, context: .default) self.sql.append(contentsOf: "$\(self.binds.count)") } @inlinable - public mutating func appendInterpolation(_ value: Optional) throws { + public mutating func appendInterpolation(_ value: Optional) throws { switch value { case .none: self.binds.appendNull() @@ -62,13 +62,13 @@ extension PostgresQuery { } @inlinable - public mutating func appendInterpolation(_ value: Value) { + public mutating func appendInterpolation(_ value: Value) { self.binds.append(value, context: .default) self.sql.append(contentsOf: "$\(self.binds.count)") } @inlinable - public mutating func appendInterpolation(_ value: Optional) { + public mutating func appendInterpolation(_ value: Optional) { switch value { case .none: self.binds.appendNull() @@ -80,7 +80,7 @@ extension PostgresQuery { } @inlinable - public mutating func appendInterpolation( + public mutating func appendInterpolation( _ value: Value, context: PostgresEncodingContext ) throws { @@ -136,8 +136,8 @@ public struct PostgresBindings: Sendable, Hashable { } @inlinable - init(value: Value, protected: Bool) { - self.init(dataType: Value.psqlType, format: Value.psqlFormat, protected: protected) + init(value: Value, protected: Bool) { + self.init(dataType: value.psqlType, format: value.psqlFormat, protected: protected) } } @@ -168,12 +168,12 @@ public struct PostgresBindings: Sendable, Hashable { } @inlinable - public mutating func append(_ value: Value) throws { + public mutating func append(_ value: Value) throws { try self.append(value, context: .default) } @inlinable - public mutating func append( + public mutating func append( _ value: Value, context: PostgresEncodingContext ) throws { @@ -182,12 +182,12 @@ public struct PostgresBindings: Sendable, Hashable { } @inlinable - public mutating func append(_ value: Value) { + public mutating func append(_ value: Value) { self.append(value, context: .default) } @inlinable - public mutating func append( + public mutating func append( _ value: Value, context: PostgresEncodingContext ) { diff --git a/Tests/PostgresNIOTests/New/PostgresQueryTests.swift b/Tests/PostgresNIOTests/New/PostgresQueryTests.swift index f50d414a..4930f0c4 100644 --- a/Tests/PostgresNIOTests/New/PostgresQueryTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresQueryTests.swift @@ -31,6 +31,27 @@ final class PostgresQueryTests: XCTestCase { XCTAssertEqual(query.binds.bytes, expected) } + func testStringInterpolationWithDynamicType() { + let type = PostgresDataType(16435) + let format = PostgresFormat.binary + let dynamicString = DynamicString(value: "Hello world", psqlType: type, psqlFormat: format) + + let query: PostgresQuery = """ + INSERT INTO foo (dynamicType) SET (\(dynamicString)); + """ + + XCTAssertEqual(query.sql, "INSERT INTO foo (dynamicType) SET ($1);") + + var expectedBindsBytes = ByteBuffer() + expectedBindsBytes.writeInteger(Int32(dynamicString.value.utf8.count)) + expectedBindsBytes.writeString(dynamicString.value) + + let expectedMetadata: [PostgresBindings.Metadata] = [.init(dataType: type, format: format, protected: true)] + + XCTAssertEqual(query.binds.bytes, expectedBindsBytes) + XCTAssertEqual(query.binds.metadata, expectedMetadata) + } + func testStringInterpolationWithCustomJSONEncoder() { struct Foo: Codable, PostgresCodable { var helloWorld: String @@ -89,3 +110,19 @@ final class PostgresQueryTests: XCTestCase { XCTAssertEqual(query.binds.bytes, expected) } } + +extension PostgresQueryTests { + struct DynamicString: PostgresDynamicTypeEncodable { + let value: String + + var psqlType: PostgresDataType + var psqlFormat: PostgresFormat + + func encode( + into byteBuffer: inout ByteBuffer, + context: PostgresNIO.PostgresEncodingContext + ) where JSONEncoder: PostgresJSONEncoder { + byteBuffer.writeString(value) + } + } +} From d89a72304d2cf847f115773467432ce955e43981 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Mon, 28 Aug 2023 03:38:11 -0500 Subject: [PATCH 185/292] Improve the logo image used by the DocC catalog (#404) --- .../Docs.docc/images/vapor-postgres-logo.svg | 37 +++++++++++-------- .../PostgresNIO/Docs.docc/theme-settings.json | 6 +-- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg b/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg index e1c1223b..d118faab 100644 --- a/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg +++ b/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg @@ -2,35 +2,40 @@ - - - - - - - - - - - - + PostgresNIO + + + + + + + + + + + + + + + + diff --git a/Sources/PostgresNIO/Docs.docc/theme-settings.json b/Sources/PostgresNIO/Docs.docc/theme-settings.json index c6ce054e..e9fc3d9d 100644 --- a/Sources/PostgresNIO/Docs.docc/theme-settings.json +++ b/Sources/PostgresNIO/Docs.docc/theme-settings.json @@ -22,14 +22,14 @@ "light": "rgb(255, 255, 255)" }, "psql-blue": "#336791", - "documentation-intro-fill": "radial-gradient(circle at top, var(--color-documentation-intro-accent) 30%, #1f1d1f 100%)", + "documentation-intro-fill": "radial-gradient(circle at top, var(--color-documentation-intro-accent) 30%, #000 100%)", "documentation-intro-accent": "var(--color-psql-blue)", "documentation-intro-accent-outer": { "dark": "rgb(255, 255, 255)", - "light": "rgb(51, 51, 51)" + "light": "rgb(0, 0, 0)" }, "documentation-intro-accent-inner": { - "dark": "rgb(51, 51, 51)", + "dark": "rgb(0, 0, 0)", "light": "rgb(255, 255, 255)" } }, From abca6b390235ae337999d367c40cc40c99629385 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 29 Aug 2023 18:09:29 +0200 Subject: [PATCH 186/292] Fix Segmentation faults in Swift 5.8 (#406) --- .../ConnectionStateMachine.swift | 292 ++++++++---------- 1 file changed, 122 insertions(+), 170 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index eca251ff..125d26bb 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -333,11 +333,10 @@ struct ConnectionStateMachine { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.authentication(message))) } - return self.avoidingStateMachineCoW { machine in - let action = authState.authenticationMessageReceived(message) - machine.state = .authenticating(authState) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = authState.authenticationMessageReceived(message) + self.state = .authenticating(authState) + return self.modify(with: action) } mutating func backendKeyDataReceived(_ keyData: PostgresBackendMessage.BackendKeyData) -> ConnectionAction { @@ -363,29 +362,29 @@ struct ConnectionStateMachine { .closing: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.parameterStatus(status))) case .authenticated(let keyData, var parameters): - return self.avoidingStateMachineCoW { machine in - parameters[status.parameter] = status.value - machine.state = .authenticated(keyData, parameters) - return .wait - } + self.state = .modifying // avoid CoW + parameters[status.parameter] = status.value + self.state = .authenticated(keyData, parameters) + return .wait + case .readyForQuery(var connectionContext): - return self.avoidingStateMachineCoW { machine in - connectionContext.parameters[status.parameter] = status.value - machine.state = .readyForQuery(connectionContext) - return .wait - } + self.state = .modifying // avoid CoW + connectionContext.parameters[status.parameter] = status.value + self.state = .readyForQuery(connectionContext) + return .wait + case .extendedQuery(let query, var connectionContext): - return self.avoidingStateMachineCoW { machine in - connectionContext.parameters[status.parameter] = status.value - machine.state = .extendedQuery(query, connectionContext) - return .wait - } + self.state = .modifying // avoid CoW + connectionContext.parameters[status.parameter] = status.value + self.state = .extendedQuery(query, connectionContext) + return .wait + case .closeCommand(let closeState, var connectionContext): - return self.avoidingStateMachineCoW { machine in - connectionContext.parameters[status.parameter] = status.value - machine.state = .closeCommand(closeState, connectionContext) - return .wait - } + self.state = .modifying // avoid CoW + connectionContext.parameters[status.parameter] = status.value + self.state = .closeCommand(closeState, connectionContext) + return .wait + case .initialized, .closed: preconditionFailure("We shouldn't receive messages if we are not connected") @@ -407,29 +406,29 @@ struct ConnectionStateMachine { if authState.isComplete { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.error(errorMessage))) } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = authState.errorReceived(errorMessage) - machine.state = .authenticating(authState) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = authState.errorReceived(errorMessage) + self.state = .authenticating(authState) + return self.modify(with: action) + case .closeCommand(var closeStateMachine, let connectionContext): if closeStateMachine.isComplete { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.error(errorMessage))) } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = closeStateMachine.errorReceived(errorMessage) - machine.state = .closeCommand(closeStateMachine, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = closeStateMachine.errorReceived(errorMessage) + self.state = .closeCommand(closeStateMachine, connectionContext) + return self.modify(with: action) + case .extendedQuery(var extendedQueryState, let connectionContext): if extendedQueryState.isComplete { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.error(errorMessage))) } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = extendedQueryState.errorReceived(errorMessage) - machine.state = .extendedQuery(extendedQueryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = extendedQueryState.errorReceived(errorMessage) + self.state = .extendedQuery(extendedQueryState, connectionContext) + return self.modify(with: action) + case .closing: // If the state machine is in state `.closing`, the connection shutdown was initiated // by the client. This means a `TERMINATE` message has already been sent and the @@ -492,11 +491,11 @@ struct ConnectionStateMachine { mutating func noticeReceived(_ notice: PostgresBackendMessage.NoticeResponse) -> ConnectionAction { switch self.state { case .extendedQuery(var extendedQuery, let connectionContext): - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = extendedQuery.noticeReceived(notice) - machine.state = .extendedQuery(extendedQuery, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = extendedQuery.noticeReceived(notice) + self.state = .extendedQuery(extendedQuery, connectionContext) + return self.modify(with: action) + default: return .wait } @@ -612,11 +611,10 @@ struct ConnectionStateMachine { return .wait case .extendedQuery(var extendedQuery, let connectionContext): - return self.avoidingStateMachineCoW { machine in - let action = extendedQuery.channelReadComplete() - machine.state = .extendedQuery(extendedQuery, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = extendedQuery.channelReadComplete() + self.state = .extendedQuery(extendedQuery, connectionContext) + return self.modify(with: action) case .modifying: preconditionFailure("Invalid state") @@ -642,17 +640,17 @@ struct ConnectionStateMachine { case .readyForQuery: return .read case .extendedQuery(var extendedQuery, let connectionContext): - return self.avoidingStateMachineCoW { machine in - let action = extendedQuery.readEventCaught() - machine.state = .extendedQuery(extendedQuery, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = extendedQuery.readEventCaught() + self.state = .extendedQuery(extendedQuery, connectionContext) + return self.modify(with: action) + case .closeCommand(var closeState, let connectionContext): - return self.avoidingStateMachineCoW { machine in - let action = closeState.readEventCaught() - machine.state = .closeCommand(closeState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = closeState.readEventCaught() + self.state = .closeCommand(closeState, connectionContext) + return self.modify(with: action) + case .closing: return .read case .closed: @@ -667,11 +665,11 @@ struct ConnectionStateMachine { mutating func parseCompleteReceived() -> ConnectionAction { switch self.state { case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.parseCompletedReceived() - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.parseCompletedReceived() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.parseComplete)) } @@ -682,21 +680,20 @@ struct ConnectionStateMachine { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.bindComplete)) } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.bindCompleteReceived() - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.bindCompleteReceived() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) } mutating func parameterDescriptionReceived(_ description: PostgresBackendMessage.ParameterDescription) -> ConnectionAction { switch self.state { case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.parameterDescriptionReceived(description) - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.parameterDescriptionReceived(description) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.parameterDescription(description))) } @@ -705,11 +702,11 @@ struct ConnectionStateMachine { mutating func rowDescriptionReceived(_ description: RowDescription) -> ConnectionAction { switch self.state { case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.rowDescriptionReceived(description) - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.rowDescriptionReceived(description) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.rowDescription(description))) } @@ -718,11 +715,11 @@ struct ConnectionStateMachine { mutating func noDataReceived() -> ConnectionAction { switch self.state { case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.noDataReceived() - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.noDataReceived() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.noData)) } @@ -737,11 +734,10 @@ struct ConnectionStateMachine { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.closeComplete)) } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = closeState.closeCompletedReceived() - machine.state = .closeCommand(closeState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = closeState.closeCompletedReceived() + self.state = .closeCommand(closeState, connectionContext) + return self.modify(with: action) } mutating func commandCompletedReceived(_ commandTag: String) -> ConnectionAction { @@ -749,11 +745,10 @@ struct ConnectionStateMachine { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.commandComplete(commandTag))) } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.commandCompletedReceived(commandTag) - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.commandCompletedReceived(commandTag) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) } mutating func emptyQueryResponseReceived() -> ConnectionAction { @@ -761,11 +756,10 @@ struct ConnectionStateMachine { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.emptyQueryResponse)) } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.emptyQueryResponseReceived() - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.emptyQueryResponseReceived() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) } mutating func dataRowReceived(_ dataRow: DataRow) -> ConnectionAction { @@ -773,11 +767,10 @@ struct ConnectionStateMachine { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.dataRow(dataRow))) } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.dataRowReceived(dataRow) - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.dataRowReceived(dataRow) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) } // MARK: Consumer @@ -787,11 +780,10 @@ struct ConnectionStateMachine { preconditionFailure("Tried to cancel stream without active query") } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.cancel() - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.cancel() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) } mutating func requestQueryRows() -> ConnectionAction { @@ -799,11 +791,10 @@ struct ConnectionStateMachine { preconditionFailure("Tried to consume next row, without active query") } - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - let action = queryState.requestQueryRows() - machine.state = .extendedQuery(queryState, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + let action = queryState.requestQueryRows() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) } // MARK: - Private Methods - @@ -813,12 +804,11 @@ struct ConnectionStateMachine { preconditionFailure("Can only start authentication after connect or ssl establish") } - return self.avoidingStateMachineCoW { machine in - var authState = AuthenticationStateMachine(authContext: authContext) - let action = authState.start() - machine.state = .authenticating(authState) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + var authState = AuthenticationStateMachine(authContext: authContext) + let action = authState.start() + self.state = .authenticating(authState) + return self.modify(with: action) } private mutating func closeConnectionAndCleanup(_ error: PSQLError, closePromise: EventLoopPromise? = nil) -> ConnectionAction { @@ -944,19 +934,18 @@ struct ConnectionStateMachine { switch task { case .extendedQuery(let queryContext): - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - var extendedQuery = ExtendedQueryStateMachine(queryContext: queryContext) - let action = extendedQuery.start() - machine.state = .extendedQuery(extendedQuery, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + var extendedQuery = ExtendedQueryStateMachine(queryContext: queryContext) + let action = extendedQuery.start() + self.state = .extendedQuery(extendedQuery, connectionContext) + return self.modify(with: action) + case .closeCommand(let closeContext): - return self.avoidingStateMachineCoW { machine -> ConnectionAction in - var closeStateMachine = CloseStateMachine(closeContext: closeContext) - let action = closeStateMachine.start() - machine.state = .closeCommand(closeStateMachine, connectionContext) - return machine.modify(with: action) - } + self.state = .modifying // avoid CoW + var closeStateMachine = CloseStateMachine(closeContext: closeContext) + let action = closeStateMachine.start() + self.state = .closeCommand(closeStateMachine, connectionContext) + return self.modify(with: action) } } @@ -965,43 +954,6 @@ struct ConnectionStateMachine { } } -// MARK: CoW helpers - -extension ConnectionStateMachine { - /// So, uh...this function needs some explaining. - /// - /// While the state machine logic above is great, there is a downside to having all of the state machine data in - /// associated data on enumerations: any modification of that data will trigger copy on write for heap-allocated - /// data. That means that for _every operation on the state machine_ we will CoW our underlying state, which is - /// not good. - /// - /// The way we can avoid this is by using this helper function. It will temporarily set state to a value with no - /// associated data, before attempting the body of the function. It will also verify that the state machine never - /// remains in this bad state. - /// - /// A key note here is that all callers must ensure that they return to a good state before they exit. - /// - /// Sadly, because it's generic and has a closure, we need to force it to be inlined at all call sites, which is - /// not ideal. - @inline(__always) - private mutating func avoidingStateMachineCoW(_ body: (inout ConnectionStateMachine) -> ReturnType) -> ReturnType { - self.state = .modifying - defer { - assert(!self.isModifying) - } - - return body(&self) - } - - private var isModifying: Bool { - if case .modifying = self.state { - return true - } else { - return false - } - } -} - extension ConnectionStateMachine { func shouldCloseConnection(reason error: PSQLError) -> Bool { switch error.code.base { From 92ee156a649b88f8926bcad6056cf77126b90405 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 18 Sep 2023 21:44:17 +0200 Subject: [PATCH 187/292] Update SSWG Graduation Level (#409) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b4f8f70e..2123262f 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@

-SSWG Incubation +SSWG Incubation Level: Graduated Documentation MIT License Continuous Integration From 4ab6d0aa7ac71f74f9d69094786a6d9e447b5722 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Thu, 12 Oct 2023 15:21:42 -0500 Subject: [PATCH 188/292] Update minimum Swift requirement to 5.7 (#414) Bump required Swift to 5.7, update dependency version requirements, update CI for Swift and Postgres versions, do some interesting things with the API docs and README. --- .github/workflows/test.yml | 49 +++++++++--------- Package.swift | 16 +++--- README.md | 22 +++++--- .../Docs.docc/images/vapor-postgres-logo.svg | 51 +++++++++++++------ .../PostgresNIO/Docs.docc/theme-settings.json | 2 +- docker-compose.yml | 3 ++ 6 files changed, 89 insertions(+), 54 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2da05f81..91895532 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,13 +18,13 @@ jobs: fail-fast: false matrix: swift-image: - - swift:5.6-focal - swift:5.7-jammy - swift:5.8-jammy - - swiftlang/swift:nightly-5.9-jammy + - swift:5.9-jammy + - swiftlang/swift:nightly-5.10-jammy - swiftlang/swift:nightly-main-jammy include: - - swift-image: swift:5.8-jammy + - swift-image: swift:5.9-jammy code-coverage: true container: ${{ matrix.swift-image }} runs-on: ubuntu-latest @@ -37,7 +37,7 @@ jobs: printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" swift --version - name: Check out package - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Run unit tests with Thread Sanitizer env: CODE_COVERAGE: ${{ matrix.code-coverage && '--enable-code-coverage' || '' }} @@ -53,18 +53,18 @@ jobs: fail-fast: false matrix: postgres-image: - - postgres:15 - - postgres:13 - - postgres:11 + - postgres:16 + - postgres:14 + - postgres:12 include: - - postgres-image: postgres:15 + - postgres-image: postgres:16 postgres-auth: scram-sha-256 - - postgres-image: postgres:13 + - postgres-image: postgres:14 postgres-auth: md5 - - postgres-image: postgres:11 + - postgres-image: postgres:12 postgres-auth: trust container: - image: swift:5.8-jammy + image: swift:5.9-jammy volumes: [ 'pgrunshare:/var/run/postgresql' ] runs-on: ubuntu-latest env: @@ -109,15 +109,15 @@ jobs: [[ -z "${SWIFT_VERSION}" ]] && SWIFT_VERSION="$(cat /.swift_tag 2>/dev/null || true)" printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" && swift --version - name: Check out package - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: { path: 'postgres-nio' } - name: Run integration tests run: swift test --package-path postgres-nio --filter=^IntegrationTests - name: Check out postgres-kit dependent - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: { repository: 'vapor/postgres-kit', path: 'postgres-kit' } - name: Check out fluent-postgres-driver dependent - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: { repository: 'vapor/fluent-postgres-driver', path: 'fluent-postgres-driver' } - name: Use local package in dependents run: | @@ -135,13 +135,13 @@ jobs: matrix: postgres-formula: # Only test one version on macOS, let Linux do the rest - - postgresql@14 + - postgresql@15 postgres-auth: # Only test one auth method on macOS, Linux tests will cover the others - scram-sha-256 xcode-version: - '~14.3' - - '15.0-beta' + - '~15.0' runs-on: macos-13 env: POSTGRES_HOSTNAME: 127.0.0.1 @@ -164,7 +164,7 @@ jobs: pg_ctl start --wait timeout-minutes: 2 - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Run all tests run: swift test @@ -174,21 +174,24 @@ jobs: container: swift:jammy steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 # https://github.com/actions/checkout/issues/766 - - name: Mark the workspace as safe - run: git config --global --add safe.directory "${GITHUB_WORKSPACE}" - name: API breaking changes - run: swift package diagnose-api-breaking-changes origin/main + run: | + git config --global --add safe.directory "${GITHUB_WORKSPACE}" + swift package diagnose-api-breaking-changes origin/main gh-codeql: runs-on: ubuntu-latest - permissions: { security-events: write } + container: swift:5.8-jammy # CodeQL currently broken with 5.9 + permissions: { actions: write, contents: read, security-events: write } steps: - name: Check out code - uses: actions/checkout@v3 + uses: actions/checkout@v4 + - name: Mark repo safe in non-fake global config + run: git config --global --add safe.directory "${GITHUB_WORKSPACE}" - name: Initialize CodeQL uses: github/codeql-action/init@v2 with: diff --git a/Package.swift b/Package.swift index a45925ed..b3ff085c 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.6 +// swift-tools-version:5.7 import PackageDescription let package = Package( @@ -13,13 +13,13 @@ let package = Package( .library(name: "PostgresNIO", targets: ["PostgresNIO"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-atomics.git", from: "1.1.0"), - .package(url: "https://github.com/apple/swift-nio.git", from: "2.58.0"), - .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.18.0"), - .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.23.1"), - .package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "3.0.0"), - .package(url: "https://github.com/apple/swift-metrics.git", from: "2.0.0"), - .package(url: "https://github.com/apple/swift-log.git", from: "1.5.2"), + .package(url: "https://github.com/apple/swift-atomics.git", from: "1.2.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.59.0"), + .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.19.0"), + .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.25.0"), + .package(url: "https://github.com/apple/swift-crypto.git", "2.0.0" ..< "4.0.0"), + .package(url: "https://github.com/apple/swift-metrics.git", from: "2.4.1"), + .package(url: "https://github.com/apple/swift-log.git", from: "1.5.3"), ], targets: [ .target( diff --git a/README.md b/README.md index 2123262f..bca6e82a 100644 --- a/README.md +++ b/README.md @@ -6,11 +6,21 @@

-SSWG Incubation Level: Graduated -Documentation -MIT License -Continuous Integration -Swift 5.6 + + Documentation + + + MIT License + + + Continuous Integration + + + Swift 5.7 - 5.9 + + + SSWG Incubation Level: Graduated +


🐘 Non-blocking, event-driven Swift client for PostgreSQL built on [SwiftNIO]. @@ -170,7 +180,7 @@ Please see [SECURITY.md] for details on the security process. [Team Chat]: https://discord.gg/vapor [MIT License]: LICENSE [Continuous Integration]: https://github.com/vapor/postgres-nio/actions -[Swift 5.6]: https://swift.org +[Swift 5.7]: https://swift.org [Security.md]: https://github.com/vapor/.github/blob/main/SECURITY.md [`PostgresConnection`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresconnection/ diff --git a/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg b/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg index d118faab..2b3fe0b1 100644 --- a/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg +++ b/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg @@ -22,20 +22,39 @@ } PostgresNIO - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + diff --git a/Sources/PostgresNIO/Docs.docc/theme-settings.json b/Sources/PostgresNIO/Docs.docc/theme-settings.json index e9fc3d9d..a8042a54 100644 --- a/Sources/PostgresNIO/Docs.docc/theme-settings.json +++ b/Sources/PostgresNIO/Docs.docc/theme-settings.json @@ -18,7 +18,7 @@ }, "color": { "fill": { - "dark": "rgb(20, 20, 22)", + "dark": "rgb(0, 0, 0)", "light": "rgb(255, 255, 255)" }, "psql-blue": "#336791", diff --git a/docker-compose.yml b/docker-compose.yml index 68797651..3eff4249 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,6 +10,9 @@ x-shared-config: &shared_config - 5432:5432 services: + psql-16: + image: postgres:16 + <<: *shared_config psql-15: image: postgres:15 <<: *shared_config From 1a76cdc6dc9ba9a967b79a3593ec30ce34669f29 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Thu, 12 Oct 2023 17:42:24 -0500 Subject: [PATCH 189/292] [skip ci] Fix up README --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index bca6e82a..489d0e29 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@

- Documentation + Documentation MIT License @@ -16,10 +16,10 @@ Continuous Integration - Swift 5.7 - 5.9 + Swift 5.7 - 5.9 - SSWG Incubation Level: Graduated + SSWG Incubation Level: Graduated


From d4d7bed0fde77934a829daed5113f95ceaa7aba0 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 13 Oct 2023 09:02:18 +0200 Subject: [PATCH 190/292] Add target `ConnectionPoolModule` (#412) Add `ConnectionPoolModule` We want to land a new ConnectionPool into PostgresNIO in the comming weeks. Since this pool is abstract, let's create a target and product for it. The target and product are both underscored, to signal that we don't make any API stability guarantees. --- Package.swift | 21 +++++++++++++++++++ Sources/ConnectionPoolModule/gitkeep.swift | 1 + Tests/ConnectionPoolModuleTests/gitkeep.swift | 1 + 3 files changed, 23 insertions(+) create mode 100644 Sources/ConnectionPoolModule/gitkeep.swift create mode 100644 Tests/ConnectionPoolModuleTests/gitkeep.swift diff --git a/Package.swift b/Package.swift index b3ff085c..814335bd 100644 --- a/Package.swift +++ b/Package.swift @@ -11,9 +11,11 @@ let package = Package( ], products: [ .library(name: "PostgresNIO", targets: ["PostgresNIO"]), + .library(name: "_ConnectionPoolModule", targets: ["_ConnectionPoolModule"]), ], dependencies: [ .package(url: "https://github.com/apple/swift-atomics.git", from: "1.2.0"), + .package(url: "https://github.com/apple/swift-collections.git", from: "1.0.4"), .package(url: "https://github.com/apple/swift-nio.git", from: "2.59.0"), .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.19.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.25.0"), @@ -25,6 +27,7 @@ let package = Package( .target( name: "PostgresNIO", dependencies: [ + .target(name: "_ConnectionPoolModule"), .product(name: "Atomics", package: "swift-atomics"), .product(name: "Crypto", package: "swift-crypto"), .product(name: "Logging", package: "swift-log"), @@ -38,6 +41,14 @@ let package = Package( .product(name: "NIOFoundationCompat", package: "swift-nio"), ] ), + .target( + name: "_ConnectionPoolModule", + dependencies: [ + .product(name: "Atomics", package: "swift-atomics"), + .product(name: "DequeModule", package: "swift-collections"), + ], + path: "Sources/ConnectionPoolModule" + ), .testTarget( name: "PostgresNIOTests", dependencies: [ @@ -46,6 +57,16 @@ let package = Package( .product(name: "NIOTestUtils", package: "swift-nio"), ] ), + .testTarget( + name: "ConnectionPoolModuleTests", + dependencies: [ + .target(name: "_ConnectionPoolModule"), + .product(name: "DequeModule", package: "swift-collections"), + .product(name: "NIOCore", package: "swift-nio"), + .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), + .product(name: "NIOEmbedded", package: "swift-nio"), + ] + ), .testTarget( name: "IntegrationTests", dependencies: [ diff --git a/Sources/ConnectionPoolModule/gitkeep.swift b/Sources/ConnectionPoolModule/gitkeep.swift new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/Sources/ConnectionPoolModule/gitkeep.swift @@ -0,0 +1 @@ + diff --git a/Tests/ConnectionPoolModuleTests/gitkeep.swift b/Tests/ConnectionPoolModuleTests/gitkeep.swift new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/gitkeep.swift @@ -0,0 +1 @@ + From d6d3510c7053246de7a673d999bd0ed6f23fe468 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Fri, 13 Oct 2023 07:50:00 -0500 Subject: [PATCH 191/292] Fix test filter in CI --- .github/workflows/test.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 91895532..cc34ddcd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -42,13 +42,12 @@ jobs: env: CODE_COVERAGE: ${{ matrix.code-coverage && '--enable-code-coverage' || '' }} run: | - swift test --filter=^PostgresNIOTests --sanitize=thread ${CODE_COVERAGE} + swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' --sanitize=thread ${CODE_COVERAGE} - name: Submit code coverage if: ${{ matrix.code-coverage }} uses: vapor/swift-codecov-action@v0.2 linux-integration-and-dependencies: - if: github.event_name == 'pull_request' strategy: fail-fast: false matrix: @@ -129,7 +128,6 @@ jobs: run: swift test --package-path fluent-postgres-driver macos-all: - if: github.event_name == 'pull_request' strategy: fail-fast: false matrix: From c6c28a6df558dabc338aa1c42a77de28a40d43b7 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 13 Oct 2023 15:58:01 +0200 Subject: [PATCH 192/292] Vendor SwiftNIO NIOLock into the new `ConnectionPoolModule` target (#416) The new `ConnectionPoolModule` shall be dependency free. But we need a lock. Let's vendor NIOLock from SwiftNIO. --- NOTICE.txt | 9 +- Sources/ConnectionPoolModule/NIOLock.swift | 268 +++++++++++++++++++++ Sources/ConnectionPoolModule/gitkeep.swift | 1 - 3 files changed, 276 insertions(+), 2 deletions(-) create mode 100644 Sources/ConnectionPoolModule/NIOLock.swift delete mode 100644 Sources/ConnectionPoolModule/gitkeep.swift diff --git a/NOTICE.txt b/NOTICE.txt index 9547a780..e704f7e6 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -2,7 +2,7 @@ // // This source file is part of the Vapor open source project // -// Copyright (c) 2017-2021 Vapor project authors +// Copyright (c) 2017-2023 Vapor project authors // Licensed under MIT // // See LICENSE for license information @@ -11,3 +11,10 @@ // //===----------------------------------------------------------------------===// +This product contains a derivation of the NIOLock implementation +from Swift NIO. + + * LICENSE (Apache License 2.0): + * https://www.apache.org/licenses/LICENSE-2.0 + * HOMEPAGE: + * https://github.com/apple/swift-nio diff --git a/Sources/ConnectionPoolModule/NIOLock.swift b/Sources/ConnectionPoolModule/NIOLock.swift new file mode 100644 index 00000000..dbc7dbe9 --- /dev/null +++ b/Sources/ConnectionPoolModule/NIOLock.swift @@ -0,0 +1,268 @@ +// Implementation vendored from SwiftNIO: +// https://github.com/apple/swift-nio + +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if canImport(Darwin) +import Darwin +#elseif os(Windows) +import ucrt +import WinSDK +#elseif canImport(Glibc) +import Glibc +#elseif canImport(Musl) +import Musl +#else +#error("The concurrency NIOLock module was unable to identify your C library.") +#endif + +#if os(Windows) +@usableFromInline +typealias LockPrimitive = SRWLOCK +#else +@usableFromInline +typealias LockPrimitive = pthread_mutex_t +#endif + +@usableFromInline +enum LockOperations { } + +extension LockOperations { + @inlinable + static func create(_ mutex: UnsafeMutablePointer) { + mutex.assertValidAlignment() + +#if os(Windows) + InitializeSRWLock(mutex) +#else + var attr = pthread_mutexattr_t() + pthread_mutexattr_init(&attr) + debugOnly { + pthread_mutexattr_settype(&attr, .init(PTHREAD_MUTEX_ERRORCHECK)) + } + + let err = pthread_mutex_init(mutex, &attr) + precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") +#endif + } + + @inlinable + static func destroy(_ mutex: UnsafeMutablePointer) { + mutex.assertValidAlignment() + +#if os(Windows) + // SRWLOCK does not need to be free'd +#else + let err = pthread_mutex_destroy(mutex) + precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") +#endif + } + + @inlinable + static func lock(_ mutex: UnsafeMutablePointer) { + mutex.assertValidAlignment() + +#if os(Windows) + AcquireSRWLockExclusive(mutex) +#else + let err = pthread_mutex_lock(mutex) + precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") +#endif + } + + @inlinable + static func unlock(_ mutex: UnsafeMutablePointer) { + mutex.assertValidAlignment() + +#if os(Windows) + ReleaseSRWLockExclusive(mutex) +#else + let err = pthread_mutex_unlock(mutex) + precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") +#endif + } +} + +// Tail allocate both the mutex and a generic value using ManagedBuffer. +// Both the header pointer and the elements pointer are stable for +// the class's entire lifetime. +// +// However, for safety reasons, we elect to place the lock in the "elements" +// section of the buffer instead of the head. The reasoning here is subtle, +// so buckle in. +// +// _As a practical matter_, the implementation of ManagedBuffer ensures that +// the pointer to the header is stable across the lifetime of the class, and so +// each time you call `withUnsafeMutablePointers` or `withUnsafeMutablePointerToHeader` +// the value of the header pointer will be the same. This is because ManagedBuffer uses +// `Builtin.addressOf` to load the value of the header, and that does ~magic~ to ensure +// that it does not invoke any weird Swift accessors that might copy the value. +// +// _However_, the header is also available via the `.header` field on the ManagedBuffer. +// This presents a problem! The reason there's an issue is that `Builtin.addressOf` and friends +// do not interact with Swift's exclusivity model. That is, the various `with` functions do not +// conceptually trigger a mutating access to `.header`. For elements this isn't a concern because +// there's literally no other way to perform the access, but for `.header` it's entirely possible +// to accidentally recursively read it. +// +// Our implementation is free from these issues, so we don't _really_ need to worry about it. +// However, out of an abundance of caution, we store the Value in the header, and the LockPrimitive +// in the trailing elements. We still don't use `.header`, but it's better to be safe than sorry, +// and future maintainers will be happier that we were cautious. +// +// See also: https://github.com/apple/swift/pull/40000 +@usableFromInline +final class LockStorage: ManagedBuffer { + + @inlinable + static func create(value: Value) -> Self { + let buffer = Self.create(minimumCapacity: 1) { _ in + return value + } + let storage = unsafeDowncast(buffer, to: Self.self) + + storage.withUnsafeMutablePointers { _, lockPtr in + LockOperations.create(lockPtr) + } + + return storage + } + + @inlinable + func lock() { + self.withUnsafeMutablePointerToElements { lockPtr in + LockOperations.lock(lockPtr) + } + } + + @inlinable + func unlock() { + self.withUnsafeMutablePointerToElements { lockPtr in + LockOperations.unlock(lockPtr) + } + } + + @inlinable + deinit { + self.withUnsafeMutablePointerToElements { lockPtr in + LockOperations.destroy(lockPtr) + } + } + + @inlinable + func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { + try self.withUnsafeMutablePointerToElements { lockPtr in + return try body(lockPtr) + } + } + + @inlinable + func withLockedValue(_ mutate: (inout Value) throws -> T) rethrows -> T { + try self.withUnsafeMutablePointers { valuePtr, lockPtr in + LockOperations.lock(lockPtr) + defer { LockOperations.unlock(lockPtr) } + return try mutate(&valuePtr.pointee) + } + } +} + +extension LockStorage: @unchecked Sendable { } + +/// A threading lock based on `libpthread` instead of `libdispatch`. +/// +/// - note: ``NIOLock`` has reference semantics. +/// +/// This object provides a lock on top of a single `pthread_mutex_t`. This kind +/// of lock is safe to use with `libpthread`-based threading models, such as the +/// one used by NIO. On Windows, the lock is based on the substantially similar +/// `SRWLOCK` type. +@usableFromInline +struct NIOLock { + @usableFromInline + internal let _storage: LockStorage + + /// Create a new lock. + @inlinable + init() { + self._storage = .create(value: ()) + } + + /// Acquire the lock. + /// + /// Whenever possible, consider using `withLock` instead of this method and + /// `unlock`, to simplify lock handling. + @inlinable + func lock() { + self._storage.lock() + } + + /// Release the lock. + /// + /// Whenever possible, consider using `withLock` instead of this method and + /// `lock`, to simplify lock handling. + @inlinable + func unlock() { + self._storage.unlock() + } + + @inlinable + internal func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { + return try self._storage.withLockPrimitive(body) + } +} + +extension NIOLock { + /// Acquire the lock for the duration of the given block. + /// + /// This convenience method should be preferred to `lock` and `unlock` in + /// most situations, as it ensures that the lock will be released regardless + /// of how `body` exits. + /// + /// - Parameter body: The block to execute while holding the lock. + /// - Returns: The value returned by the block. + @inlinable + func withLock(_ body: () throws -> T) rethrows -> T { + self.lock() + defer { + self.unlock() + } + return try body() + } + + @inlinable + func withLockVoid(_ body: () throws -> Void) rethrows -> Void { + try self.withLock(body) + } +} + +extension NIOLock: Sendable {} + +extension UnsafeMutablePointer { + @inlinable + func assertValidAlignment() { + assert(UInt(bitPattern: self) % UInt(MemoryLayout.alignment) == 0) + } +} + +/// A utility function that runs the body code only in debug builds, without +/// emitting compiler warnings. +/// +/// This is currently the only way to do this in Swift: see +/// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion. +@inlinable +internal func debugOnly(_ body: () -> Void) { + // FIXME: duplicated with NIO. + assert({ body(); return true }()) +} diff --git a/Sources/ConnectionPoolModule/gitkeep.swift b/Sources/ConnectionPoolModule/gitkeep.swift deleted file mode 100644 index 8b137891..00000000 --- a/Sources/ConnectionPoolModule/gitkeep.swift +++ /dev/null @@ -1 +0,0 @@ - From 8fbf8ff7309921ebe73a9500b6d6a8bca161861b Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 13 Oct 2023 16:38:59 +0200 Subject: [PATCH 193/292] Add `PooledConnection` protocol (#417) --- .../ConnectionPoolModule/ConnectionPool.swift | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 Sources/ConnectionPoolModule/ConnectionPool.swift diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift new file mode 100644 index 00000000..290e0679 --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -0,0 +1,58 @@ +/// A connection that can be pooled in a ``ConnectionPool`` +public protocol PooledConnection: AnyObject, Sendable { + /// The connections identifier type. + associatedtype ID: Hashable & Sendable + + /// The connections identifier. The identifier is passed to + /// the connection factory method and must stay attached to + /// the connection at all times. It must not change during + /// the connections lifetime. + var id: ID { get } + + /// A method to register closures that are invoked when the + /// connection is closed. If the connection closed unexpectedly + /// the closure shall be called with the underlying error. + /// In most NIO clients this can be easily implemented by + /// attaching to the `channel.closeFuture`: + /// ``` + /// func onClose( + /// _ closure: @escaping @Sendable ((any Error)?) -> () + /// ) { + /// channel.closeFuture.whenComplete { _ in + /// closure(previousError) + /// } + /// } + /// ``` + func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) + + /// Close the running connection. Once the close has completed + /// closures that were registered in `onClose` must be + /// invoked. + func close() +} + +/// A connection id generator. Its returned connection IDs will +/// be used when creating new ``PooledConnection``s +public protocol ConnectionIDGeneratorProtocol: Sendable { + /// The connections identifier type. + associatedtype ID: Hashable & Sendable + + /// The next connection ID that shall be used. + func next() -> ID +} + +/// A keep alive behavior for connections maintained by the pool +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public protocol ConnectionKeepAliveBehavior: Sendable { + /// the connection type + associatedtype Connection: PooledConnection + + /// The time after which a keep-alive shall + /// be triggered. + /// If nil is returned, keep-alive is deactivated + var keepAliveFrequency: Duration? { get } + + /// This method is invoked when the keep-alive shall be + /// run. + func runKeepAlive(for connection: Connection) async throws +} From 358fa598ae6fc2fc1cde213a0d2e8bd1eaf5b2eb Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 16 Oct 2023 13:20:09 +0200 Subject: [PATCH 194/292] Add `ConnectionIDGenerator` and `NoOpKeepAliveBehavior` (#418) --- .../ConnectionIDGenerator.swift | 15 ++++ .../NoKeepAliveBehavior.swift | 8 ++ .../ConnectionIDGeneratorTests.swift | 22 ++++++ .../Mocks/MockConnection.swift | 74 +++++++++++++++++++ .../NoKeepAliveBehaviorTests.swift | 10 +++ Tests/ConnectionPoolModuleTests/gitkeep.swift | 1 - 6 files changed, 129 insertions(+), 1 deletion(-) create mode 100644 Sources/ConnectionPoolModule/ConnectionIDGenerator.swift create mode 100644 Sources/ConnectionPoolModule/NoKeepAliveBehavior.swift create mode 100644 Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift create mode 100644 Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift create mode 100644 Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift delete mode 100644 Tests/ConnectionPoolModuleTests/gitkeep.swift diff --git a/Sources/ConnectionPoolModule/ConnectionIDGenerator.swift b/Sources/ConnectionPoolModule/ConnectionIDGenerator.swift new file mode 100644 index 00000000..b428d805 --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionIDGenerator.swift @@ -0,0 +1,15 @@ +import Atomics + +public struct ConnectionIDGenerator: ConnectionIDGeneratorProtocol { + static let globalGenerator = ConnectionIDGenerator() + + private let atomic: ManagedAtomic + + public init() { + self.atomic = .init(0) + } + + public func next() -> Int { + return self.atomic.loadThenWrappingIncrement(ordering: .relaxed) + } +} diff --git a/Sources/ConnectionPoolModule/NoKeepAliveBehavior.swift b/Sources/ConnectionPoolModule/NoKeepAliveBehavior.swift new file mode 100644 index 00000000..0a7b2dee --- /dev/null +++ b/Sources/ConnectionPoolModule/NoKeepAliveBehavior.swift @@ -0,0 +1,8 @@ +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public struct NoOpKeepAliveBehavior: ConnectionKeepAliveBehavior { + public var keepAliveFrequency: Duration? { nil } + + public func runKeepAlive(for connection: Connection) async throws {} + + public init(connectionType: Connection.Type) {} +} diff --git a/Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift new file mode 100644 index 00000000..fb0bfce1 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift @@ -0,0 +1,22 @@ +import _ConnectionPoolModule +import XCTest + +final class ConnectionIDGeneratorTests: XCTestCase { + func testGenerateConnectionIDs() async { + let idGenerator = ConnectionIDGenerator() + + XCTAssertEqual(idGenerator.next(), 0) + XCTAssertEqual(idGenerator.next(), 1) + XCTAssertEqual(idGenerator.next(), 2) + + await withTaskGroup(of: Void.self) { taskGroup in + for _ in 0..<1000 { + taskGroup.addTask { + _ = idGenerator.next() + } + } + } + + XCTAssertEqual(idGenerator.next(), 1003) + } +} diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift new file mode 100644 index 00000000..6a8ed297 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift @@ -0,0 +1,74 @@ +import DequeModule +@testable import _ConnectionPoolModule + +// Sendability enforced through the lock +final class MockConnection: PooledConnection, @unchecked Sendable { + typealias ID = Int + + let id: ID + + private enum State { + case running([@Sendable ((any Error)?) -> ()]) + case closing([@Sendable ((any Error)?) -> ()]) + case closed + } + + private let lock = NIOLock() + private var _state = State.running([]) + + init(id: Int) { + self.id = id + } + + func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) { + let enqueued = self.lock.withLock { () -> Bool in + switch self._state { + case .closed: + return false + + case .running(var callbacks): + callbacks.append(closure) + self._state = .running(callbacks) + return true + + case .closing(var callbacks): + callbacks.append(closure) + self._state = .closing(callbacks) + return true + } + } + + if !enqueued { + closure(nil) + } + } + + func close() { + self.lock.withLock { + switch self._state { + case .running(let callbacks): + self._state = .closing(callbacks) + + case .closing, .closed: + break + } + } + } + + func closeIfClosing() { + let callbacks = self.lock.withLock { () -> [@Sendable ((any Error)?) -> ()] in + switch self._state { + case .running, .closed: + return [] + + case .closing(let callbacks): + self._state = .closed + return callbacks + } + } + + for callback in callbacks { + callback(nil) + } + } +} diff --git a/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift new file mode 100644 index 00000000..b817ce19 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift @@ -0,0 +1,10 @@ +import _ConnectionPoolModule +import XCTest + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class NoKeepAliveBehaviorTests: XCTestCase { + func testNoKeepAlive() { + let keepAliveBehavior = NoOpKeepAliveBehavior(connectionType: MockConnection.self) + XCTAssertNil(keepAliveBehavior.keepAliveFrequency) + } +} diff --git a/Tests/ConnectionPoolModuleTests/gitkeep.swift b/Tests/ConnectionPoolModuleTests/gitkeep.swift deleted file mode 100644 index 8b137891..00000000 --- a/Tests/ConnectionPoolModuleTests/gitkeep.swift +++ /dev/null @@ -1 +0,0 @@ - From f5a04aab09b382e30129b8a86d7284412b549435 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 16 Oct 2023 15:33:30 +0200 Subject: [PATCH 195/292] Add `OneElementFastSequence` to be used in `ConnectionPool` (#420) --- .../OneElementFastSequence.swift | 151 ++++++++++++++++++ .../OneElementFastSequence.swift | 70 ++++++++ 2 files changed, 221 insertions(+) create mode 100644 Sources/ConnectionPoolModule/OneElementFastSequence.swift create mode 100644 Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift diff --git a/Sources/ConnectionPoolModule/OneElementFastSequence.swift b/Sources/ConnectionPoolModule/OneElementFastSequence.swift new file mode 100644 index 00000000..1bb3b8e4 --- /dev/null +++ b/Sources/ConnectionPoolModule/OneElementFastSequence.swift @@ -0,0 +1,151 @@ +/// A `Sequence` that does not heap allocate, if it only carries a single element +@usableFromInline +struct OneElementFastSequence: Sequence { + @usableFromInline + enum Base { + case none(reserveCapacity: Int) + case one(Element, reserveCapacity: Int) + case n([Element]) + } + + @usableFromInline + private(set) var base: Base + + @inlinable + init() { + self.base = .none(reserveCapacity: 0) + } + + @inlinable + init(_ element: Element) { + self.base = .one(element, reserveCapacity: 1) + } + + @inlinable + init(_ collection: some Collection) { + switch collection.count { + case 0: + self.base = .none(reserveCapacity: 0) + case 1: + self.base = .one(collection.first!, reserveCapacity: 0) + default: + if let collection = collection as? Array { + self.base = .n(collection) + } else { + self.base = .n(Array(collection)) + } + } + } + + @usableFromInline + var count: Int { + switch self.base { + case .none: + return 0 + case .one: + return 1 + case .n(let array): + return array.count + } + } + + @inlinable + var first: Element? { + switch self.base { + case .none: + return nil + case .one(let element, _): + return element + case .n(let array): + return array.first + } + } + + @usableFromInline + var isEmpty: Bool { + switch self.base { + case .none: + return true + case .one, .n: + return false + } + } + + @inlinable + mutating func reserveCapacity(_ minimumCapacity: Int) { + switch self.base { + case .none(let reservedCapacity): + self.base = .none(reserveCapacity: Swift.max(reservedCapacity, minimumCapacity)) + case .one(let element, let reservedCapacity): + self.base = .one(element, reserveCapacity: Swift.max(reservedCapacity, minimumCapacity)) + case .n(var array): + self.base = .none(reserveCapacity: 0) // prevent CoW + array.reserveCapacity(minimumCapacity) + self.base = .n(array) + } + } + + @inlinable + mutating func append(_ element: Element) { + switch self.base { + case .none(let reserveCapacity): + self.base = .one(element, reserveCapacity: reserveCapacity) + case .one(let existing, let reserveCapacity): + var new = [Element]() + new.reserveCapacity(reserveCapacity) + new.append(existing) + new.append(element) + self.base = .n(new) + case .n(var existing): + self.base = .none(reserveCapacity: 0) // prevent CoW + existing.append(element) + self.base = .n(existing) + } + } + + @inlinable + func makeIterator() -> Iterator { + Iterator(self) + } + + @usableFromInline + struct Iterator: IteratorProtocol { + @usableFromInline private(set) var index: Int = 0 + @usableFromInline private(set) var backing: OneElementFastSequence + + @inlinable + init(_ backing: OneElementFastSequence) { + self.backing = backing + } + + @inlinable + mutating func next() -> Element? { + switch self.backing.base { + case .none: + return nil + case .one(let element, _): + if self.index == 0 { + self.index += 1 + return element + } + return nil + + case .n(let array): + if self.index < array.endIndex { + defer { self.index += 1} + return array[self.index] + } + return nil + } + } + } +} + +extension OneElementFastSequence: Equatable where Element: Equatable {} +extension OneElementFastSequence.Base: Equatable where Element: Equatable {} + +extension OneElementFastSequence: Hashable where Element: Hashable {} +extension OneElementFastSequence.Base: Hashable where Element: Hashable {} + +extension OneElementFastSequence: Sendable where Element: Sendable {} +extension OneElementFastSequence.Base: Sendable where Element: Sendable {} diff --git a/Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift b/Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift new file mode 100644 index 00000000..8098438f --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift @@ -0,0 +1,70 @@ +@testable import _ConnectionPoolModule +import XCTest + +final class OneElementFastSequenceTests: XCTestCase { + func testCountIsEmptyAndIterator() async { + var sequence = OneElementFastSequence() + XCTAssertEqual(sequence.count, 0) + XCTAssertEqual(sequence.isEmpty, true) + XCTAssertEqual(sequence.first, nil) + XCTAssertEqual(Array(sequence), []) + sequence.append(1) + XCTAssertEqual(sequence.count, 1) + XCTAssertEqual(sequence.isEmpty, false) + XCTAssertEqual(sequence.first, 1) + XCTAssertEqual(Array(sequence), [1]) + sequence.append(2) + XCTAssertEqual(sequence.count, 2) + XCTAssertEqual(sequence.isEmpty, false) + XCTAssertEqual(sequence.first, 1) + XCTAssertEqual(Array(sequence), [1, 2]) + sequence.append(3) + XCTAssertEqual(sequence.count, 3) + XCTAssertEqual(sequence.isEmpty, false) + XCTAssertEqual(sequence.first, 1) + XCTAssertEqual(Array(sequence), [1, 2, 3]) + } + + func testReserveCapacityIsForwarded() { + var emptySequence = OneElementFastSequence() + emptySequence.reserveCapacity(8) + emptySequence.append(1) + emptySequence.append(2) + guard case .n(let array) = emptySequence.base else { + return XCTFail("Expected sequence to be backed by an array") + } + XCTAssertEqual(array.capacity, 8) + + var oneElemSequence = OneElementFastSequence(1) + oneElemSequence.reserveCapacity(8) + oneElemSequence.append(2) + guard case .n(let array) = oneElemSequence.base else { + return XCTFail("Expected sequence to be backed by an array") + } + XCTAssertEqual(array.capacity, 8) + + var twoElemSequence = OneElementFastSequence([1, 2]) + twoElemSequence.reserveCapacity(8) + guard case .n(let array) = twoElemSequence.base else { + return XCTFail("Expected sequence to be backed by an array") + } + XCTAssertEqual(array.capacity, 8) + } + + func testNewSequenceSlowPath() { + let sequence = OneElementFastSequence("AB".utf8) + XCTAssertEqual(Array(sequence), [UInt8(ascii: "A"), UInt8(ascii: "B")]) + } + + func testSingleItem() { + let sequence = OneElementFastSequence("A".utf8) + XCTAssertEqual(Array(sequence), [UInt8(ascii: "A")]) + } + + func testEmptyCollection() { + let sequence = OneElementFastSequence("".utf8) + XCTAssertTrue(sequence.isEmpty) + XCTAssertEqual(sequence.count, 0) + XCTAssertEqual(Array(sequence), []) + } +} From 5e75c9e24db385870e19578404635891490314bf Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 16 Oct 2023 15:40:42 +0200 Subject: [PATCH 196/292] Add `Max2Sequence` to be used in `ConnectionPool` (#419) --- .../ConnectionPoolModule/Max2Sequence.swift | 95 +++++++++++++++++++ .../Max2SequenceTests.swift | 60 ++++++++++++ 2 files changed, 155 insertions(+) create mode 100644 Sources/ConnectionPoolModule/Max2Sequence.swift create mode 100644 Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift diff --git a/Sources/ConnectionPoolModule/Max2Sequence.swift b/Sources/ConnectionPoolModule/Max2Sequence.swift new file mode 100644 index 00000000..6c330067 --- /dev/null +++ b/Sources/ConnectionPoolModule/Max2Sequence.swift @@ -0,0 +1,95 @@ +// A `Sequence` that can contain at most two elements. However it does not heap allocate. +@usableFromInline +struct Max2Sequence: Sequence { + @usableFromInline + private(set) var first: Element? + @usableFromInline + private(set) var second: Element? + + @inlinable + var count: Int { + if self.first == nil { return 0 } + if self.second == nil { return 1 } + return 2 + } + + @inlinable + var isEmpty: Bool { + self.first == nil + } + + @inlinable + init(_ first: Element?, _ second: Element? = nil) { + if let first = first { + self.first = first + self.second = second + } else { + self.first = second + self.second = nil + } + } + + @inlinable + init() { + self.first = nil + self.second = nil + } + + @inlinable + func makeIterator() -> Iterator { + Iterator(first: self.first, second: self.second) + } + + @usableFromInline + struct Iterator: IteratorProtocol { + @usableFromInline + let first: Element? + @usableFromInline + let second: Element? + + @usableFromInline + private(set) var index: UInt8 = 0 + + @inlinable + init(first: Element?, second: Element?) { + self.first = first + self.second = second + self.index = 0 + } + + @inlinable + mutating func next() -> Element? { + switch self.index { + case 0: + self.index += 1 + return self.first + case 1: + self.index += 1 + return self.second + default: + return nil + } + } + } + + @inlinable + mutating func append(_ element: Element) { + precondition(self.second == nil) + if self.first == nil { + self.first = element + } else if self.second == nil { + self.second = element + } else { + fatalError("Max2Sequence can only hold two Elements.") + } + } + + @inlinable + func map(_ transform: (Element) throws -> (NewElement)) rethrows -> Max2Sequence { + try Max2Sequence(self.first.flatMap(transform), self.second.flatMap(transform)) + } +} + +extension Max2Sequence: Equatable where Element: Equatable {} +extension Max2Sequence: Hashable where Element: Hashable {} +extension Max2Sequence: Sendable where Element: Sendable {} diff --git a/Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift b/Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift new file mode 100644 index 00000000..081e867b --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift @@ -0,0 +1,60 @@ +@testable import _ConnectionPoolModule +import XCTest + +final class Max2SequenceTests: XCTestCase { + func testCountAndIsEmpty() async { + var sequence = Max2Sequence() + XCTAssertEqual(sequence.count, 0) + XCTAssertEqual(sequence.isEmpty, true) + sequence.append(1) + XCTAssertEqual(sequence.count, 1) + XCTAssertEqual(sequence.isEmpty, false) + sequence.append(2) + XCTAssertEqual(sequence.count, 2) + XCTAssertEqual(sequence.isEmpty, false) + } + + func testOptionalInitializer() { + let emptySequence = Max2Sequence(nil, nil) + XCTAssertEqual(emptySequence.count, 0) + XCTAssertEqual(emptySequence.isEmpty, true) + var emptySequenceIterator = emptySequence.makeIterator() + XCTAssertNil(emptySequenceIterator.next()) + XCTAssertNil(emptySequenceIterator.next()) + XCTAssertNil(emptySequenceIterator.next()) + + let oneElemSequence1 = Max2Sequence(1, nil) + XCTAssertEqual(oneElemSequence1.count, 1) + XCTAssertEqual(oneElemSequence1.isEmpty, false) + var oneElemSequence1Iterator = oneElemSequence1.makeIterator() + XCTAssertEqual(oneElemSequence1Iterator.next(), 1) + XCTAssertNil(oneElemSequence1Iterator.next()) + XCTAssertNil(oneElemSequence1Iterator.next()) + + let oneElemSequence2 = Max2Sequence(nil, 2) + XCTAssertEqual(oneElemSequence2.count, 1) + XCTAssertEqual(oneElemSequence2.isEmpty, false) + var oneElemSequence2Iterator = oneElemSequence2.makeIterator() + XCTAssertEqual(oneElemSequence2Iterator.next(), 2) + XCTAssertNil(oneElemSequence2Iterator.next()) + XCTAssertNil(oneElemSequence2Iterator.next()) + + let twoElemSequence = Max2Sequence(1, 2) + XCTAssertEqual(twoElemSequence.count, 2) + XCTAssertEqual(twoElemSequence.isEmpty, false) + var twoElemSequenceIterator = twoElemSequence.makeIterator() + XCTAssertEqual(twoElemSequenceIterator.next(), 1) + XCTAssertEqual(twoElemSequenceIterator.next(), 2) + XCTAssertNil(twoElemSequenceIterator.next()) + } + + func testMap() { + let twoElemSequence = Max2Sequence(1, 2).map({ "\($0)" }) + XCTAssertEqual(twoElemSequence.count, 2) + XCTAssertEqual(twoElemSequence.isEmpty, false) + var twoElemSequenceIterator = twoElemSequence.makeIterator() + XCTAssertEqual(twoElemSequenceIterator.next(), "1") + XCTAssertEqual(twoElemSequenceIterator.next(), "2") + XCTAssertNil(twoElemSequenceIterator.next()) + } +} From a57baa7f7233646449f1fde2d3fd5670de7df870 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 17 Oct 2023 12:59:39 +0200 Subject: [PATCH 197/292] Add `ConnectionRequestProtocol`, `ConnectionPoolError` and `ConnectionPoolConfiguration` (#421) --- .../ConnectionPoolModule/ConnectionPool.swift | 58 +++++++++++++++++++ .../ConnectionPoolError.swift | 16 +++++ .../ConnectionRequest.swift | 20 +++++++ 3 files changed, 94 insertions(+) create mode 100644 Sources/ConnectionPoolModule/ConnectionPoolError.swift create mode 100644 Sources/ConnectionPoolModule/ConnectionRequest.swift diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index 290e0679..825c3ab3 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -56,3 +56,61 @@ public protocol ConnectionKeepAliveBehavior: Sendable { /// run. func runKeepAlive(for connection: Connection) async throws } + +/// A request to get a connection from the `ConnectionPool` +public protocol ConnectionRequestProtocol: Sendable { + /// A connection lease request ID type. + associatedtype ID: Hashable & Sendable + /// The leased connection type + associatedtype Connection: PooledConnection + + /// A connection lease request ID. This ID must be generated + /// by users of the `ConnectionPool` outside the + /// `ConnectionPool`. It is not generated inside the pool like + /// the `ConnectionID`s. The lease request ID must be unique + /// and must not change, if your implementing type is a + /// reference type. + var id: ID { get } + + /// A function that is called with a connection or a + /// `PoolError`. + func complete(with: Result) +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public struct ConnectionPoolConfiguration { + /// The minimum number of connections to preserve in the pool. + /// + /// If the pool is mostly idle and the remote servers closes + /// idle connections, + /// the `ConnectionPool` will initiate new outbound + /// connections proactively to avoid the number of available + /// connections dropping below this number. + public var minimumConnectionCount: Int + + /// Between the `minimumConnectionCount` and + /// `maximumConnectionSoftLimit` the connection pool creates + /// _preserved_ connections. Preserved connections are closed + /// if they have been idle for ``idleTimeout``. + public var maximumConnectionSoftLimit: Int + + /// The maximum number of connections for this pool, that can + /// exist at any point in time. The pool can create _overflow_ + /// connections, if all connections are leased, and the + /// `maximumConnectionHardLimit` > `maximumConnectionSoftLimit ` + /// Overflow connections are closed immediately as soon as they + /// become idle. + public var maximumConnectionHardLimit: Int + + /// The time that a _preserved_ idle connection stays in the + /// pool before it is closed. + public var idleTimeout: Duration + + /// initializer + public init() { + self.minimumConnectionCount = 0 + self.maximumConnectionSoftLimit = 16 + self.maximumConnectionHardLimit = 16 + self.idleTimeout = .seconds(60) + } +} diff --git a/Sources/ConnectionPoolModule/ConnectionPoolError.swift b/Sources/ConnectionPoolModule/ConnectionPoolError.swift new file mode 100644 index 00000000..1f1e1d2c --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionPoolError.swift @@ -0,0 +1,16 @@ + +public struct ConnectionPoolError: Error, Hashable { + enum Base: Error, Hashable { + case requestCancelled + case poolShutdown + } + + private let base: Base + + init(_ base: Base) { self.base = base } + + /// The connection requests got cancelled + public static let requestCancelled = ConnectionPoolError(.requestCancelled) + /// The connection requests can't be fulfilled as the pool has already been shutdown + public static let poolShutdown = ConnectionPoolError(.poolShutdown) +} diff --git a/Sources/ConnectionPoolModule/ConnectionRequest.swift b/Sources/ConnectionPoolModule/ConnectionRequest.swift new file mode 100644 index 00000000..34b77084 --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionRequest.swift @@ -0,0 +1,20 @@ + +public struct ConnectionRequest: ConnectionRequestProtocol { + public typealias ID = Int + + public var id: ID + + private var continuation: CheckedContinuation + + init( + id: Int, + continuation: CheckedContinuation + ) { + self.id = id + self.continuation = continuation + } + + public func complete(with result: Result) { + self.continuation.resume(with: result) + } +} From c80a9347024892434d7c214eab8d194ee3a71bc0 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 17 Oct 2023 20:39:37 +0200 Subject: [PATCH 198/292] Add `ConnectionPoolObservabilityDelegate` (#422) --- .../ConnectionPoolObservabilityDelegate.swift | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift diff --git a/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift b/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift new file mode 100644 index 00000000..35f30dcb --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift @@ -0,0 +1,62 @@ + +public protocol ConnectionPoolObservabilityDelegate: Sendable { + associatedtype ConnectionID: Hashable & Sendable + + /// The connection with the given ID has started trying to establish a connection. The outcome + /// of the connection will be reported as either ``connectSucceeded(id:streamCapacity:)`` or + /// ``connectFailed(id:error:)``. + func startedConnecting(id: ConnectionID) + + /// A connection attempt failed with the given error. After some period of + /// time ``startedConnecting(id:)`` may be called again. + func connectFailed(id: ConnectionID, error: Error) + + /// A connection was established on the connection with the given ID. `streamCapacity` streams are + /// available to use on the connection. The maximum number of available streams may change over + /// time and is reported via ````. The + func connectSucceeded(id: ConnectionID, streamCapacity: UInt16) + + /// The utlization of the connection changed; a stream may have been used, returned or the + /// maximum number of concurrent streams available on the connection changed. + func connectionUtilizationChanged(id:ConnectionID, streamsUsed: UInt16, streamCapacity: UInt16) + + func keepAliveTriggered(id: ConnectionID) + + func keepAliveSucceeded(id: ConnectionID) + + func keepAliveFailed(id: ConnectionID, error: Error) + + /// The remote peer is quiescing the connection: no new streams will be created on it. The + /// connection will eventually be closed and removed from the pool. + func connectionClosing(id: ConnectionID) + + /// The connection was closed. The connection may be established again in the future (notified + /// via ``startedConnecting(id:)``). + func connectionClosed(id: ConnectionID, error: Error?) + + func requestQueueDepthChanged(_ newDepth: Int) +} + +public struct NoOpConnectionPoolMetrics: ConnectionPoolObservabilityDelegate { + public init(connectionIDType: ConnectionID.Type) {} + + public func startedConnecting(id: ConnectionID) {} + + public func connectFailed(id: ConnectionID, error: Error) {} + + public func connectSucceeded(id: ConnectionID, streamCapacity: UInt16) {} + + public func connectionUtilizationChanged(id: ConnectionID, streamsUsed: UInt16, streamCapacity: UInt16) {} + + public func keepAliveTriggered(id: ConnectionID) {} + + public func keepAliveSucceeded(id: ConnectionID) {} + + public func keepAliveFailed(id: ConnectionID, error: Error) {} + + public func connectionClosing(id: ConnectionID) {} + + public func connectionClosed(id: ConnectionID, error: Error?) {} + + public func requestQueueDepthChanged(_ newDepth: Int) {} +} From 8babbcff00e879173779f0d59b3fa413af4282c9 Mon Sep 17 00:00:00 2001 From: Mahdi Bahrami Date: Wed, 18 Oct 2023 16:08:07 +0330 Subject: [PATCH 199/292] Fix `PostgresDecodable` inference for `RawRepresentable` enums (#423) --- .../RawRepresentable+PostgresCodable.swift | 2 +- .../PSQLIntegrationTests.swift | 39 ++++++++++++++++++- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift index 4d6c20c4..ea097963 100644 --- a/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/RawRepresentable+PostgresCodable.swift @@ -19,7 +19,7 @@ extension PostgresEncodable where Self: RawRepresentable, RawValue: PostgresEnco } extension PostgresDecodable where Self: RawRepresentable, RawValue: PostgresDecodable, RawValue._DecodableType == RawValue { - init( + public init( from buffer: inout ByteBuffer, type: PostgresDataType, format: PostgresFormat, diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 4b2b9950..0550dc77 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -1,6 +1,6 @@ import XCTest import Logging -@testable import PostgresNIO +import PostgresNIO import NIOCore import NIOPosix import NIOTestUtils @@ -252,7 +252,7 @@ final class IntegrationTests: XCTestCase { XCTAssertNoThrow(result = try conn?.query(""" SELECT \(Decimal(string: "123456.789123")!)::numeric as numeric, - \(Decimal(string: "-123456.789123")!)::numeric as numeric_negative + \(Decimal(string: "-123456.789123")!)::numeric as numeric_negative """, logger: .psqlTest).wait()) XCTAssertEqual(result?.rows.count, 1) @@ -263,6 +263,41 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(cells?.1, Decimal(string: "-123456.789123")) } + func testDecodeRawRepresentables() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + enum StringRR: String, PostgresDecodable { + case a + } + + enum IntRR: Int, PostgresDecodable { + case b + } + + let stringValue = StringRR.a + let intValue = IntRR.b + + var result: PostgresQueryResult? + XCTAssertNoThrow(result = try conn?.query(""" + SELECT + \(stringValue.rawValue)::varchar as string, + \(intValue.rawValue)::int8 as int + """, logger: .psqlTest).wait()) + XCTAssertEqual(result?.rows.count, 1) + + var cells: (StringRR, IntRR)? + XCTAssertNoThrow(cells = try result?.rows.first?.decode((StringRR, IntRR).self, context: .default)) + + XCTAssertEqual(cells?.0, stringValue) + XCTAssertEqual(cells?.1, intValue) + } + func testRoundTripUUID() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } From 56419669833c265c4096df5341ae22f5753849cd Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 18 Oct 2023 15:46:13 +0200 Subject: [PATCH 200/292] Add `PoolStateMachine.RequestQueue` (#424) --- .../ConnectionRequest.swift | 6 +- .../OneElementFastSequence.swift | 2 +- .../PoolStateMachine+RequestQueue.swift | 71 +++++++++ .../PoolStateMachine.swift | 74 +++++++++ .../ConnectionRequestTests.swift | 27 ++++ .../Mocks/MockRequest.swift | 28 ++++ .../Mocks/MockTimerCancellationToken.swift | 16 ++ .../OneElementFastSequence.swift | 2 +- .../PoolStateMachine+RequestQueueTests.swift | 147 ++++++++++++++++++ .../PoolStateMachineTests.swift | 14 ++ 10 files changed, 383 insertions(+), 4 deletions(-) create mode 100644 Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift create mode 100644 Sources/ConnectionPoolModule/PoolStateMachine.swift create mode 100644 Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift create mode 100644 Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift create mode 100644 Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift create mode 100644 Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift create mode 100644 Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift diff --git a/Sources/ConnectionPoolModule/ConnectionRequest.swift b/Sources/ConnectionPoolModule/ConnectionRequest.swift index 34b77084..fd01bb76 100644 --- a/Sources/ConnectionPoolModule/ConnectionRequest.swift +++ b/Sources/ConnectionPoolModule/ConnectionRequest.swift @@ -4,11 +4,13 @@ public struct ConnectionRequest: ConnectionRequest public var id: ID - private var continuation: CheckedContinuation + @usableFromInline + private(set) var continuation: CheckedContinuation + @inlinable init( id: Int, - continuation: CheckedContinuation + continuation: CheckedContinuation ) { self.id = id self.continuation = continuation diff --git a/Sources/ConnectionPoolModule/OneElementFastSequence.swift b/Sources/ConnectionPoolModule/OneElementFastSequence.swift index 1bb3b8e4..3c3bfaa0 100644 --- a/Sources/ConnectionPoolModule/OneElementFastSequence.swift +++ b/Sources/ConnectionPoolModule/OneElementFastSequence.swift @@ -17,7 +17,7 @@ struct OneElementFastSequence: Sequence { } @inlinable - init(_ element: Element) { + init(element: Element) { self.base = .one(element, reserveCapacity: 1) } diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift b/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift new file mode 100644 index 00000000..7e3c6607 --- /dev/null +++ b/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift @@ -0,0 +1,71 @@ +import DequeModule + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine { + + /// A request queue, which can enqueue requests in O(1), dequeue requests in O(1) and even cancel requests in O(1). + /// + /// While enqueueing and dequeueing on O(1) is trivial, cancellation is hard, as it normally requires a removal within the + /// underlying Deque. However thanks to having an additional `requests` dictionary, we can remove the cancelled + /// request from the dictionary and keep it inside the queue. Whenever we pop a request from the deque, we validate + /// that it hasn't been cancelled in the meantime by checking if the popped request is still in the `requests` dictionary. + @usableFromInline + struct RequestQueue { + @usableFromInline + private(set) var queue: Deque + + @usableFromInline + private(set) var requests: [RequestID: Request] + + @inlinable + var count: Int { + self.requests.count + } + + @inlinable + var isEmpty: Bool { + self.count == 0 + } + + @usableFromInline + init() { + self.queue = .init(minimumCapacity: 256) + self.requests = .init(minimumCapacity: 256) + } + + @inlinable + mutating func queue(_ request: Request) { + self.requests[request.id] = request + self.queue.append(request.id) + } + + @inlinable + mutating func pop(max: UInt16) -> OneElementFastSequence { + var result = OneElementFastSequence() + result.reserveCapacity(Int(max)) + var popped = 0 + while let requestID = self.queue.popFirst(), popped < max { + if let requestIndex = self.requests.index(forKey: requestID) { + popped += 1 + result.append(self.requests.remove(at: requestIndex).value) + } + } + + assert(result.count <= max) + return result + } + + @inlinable + mutating func remove(_ requestID: RequestID) -> Request? { + self.requests.removeValue(forKey: requestID) + } + + @inlinable + mutating func removeAll() -> OneElementFastSequence { + let result = OneElementFastSequence(self.requests.values) + self.requests.removeAll() + self.queue.removeAll() + return result + } + } +} diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift new file mode 100644 index 00000000..a3962790 --- /dev/null +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -0,0 +1,74 @@ +#if canImport(Darwin) +import Darwin +#else +import Glibc +#endif + +@usableFromInline +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +struct PoolConfiguration { + /// The minimum number of connections to preserve in the pool. + /// + /// If the pool is mostly idle and the remote servers closes idle connections, + /// the `ConnectionPool` will initiate new outbound connections proactively + /// to avoid the number of available connections dropping below this number. + @usableFromInline + var minimumConnectionCount: Int = 0 + + /// The maximum number of connections to for this pool, to be preserved. + @usableFromInline + var maximumConnectionSoftLimit: Int = 10 + + @usableFromInline + var maximumConnectionHardLimit: Int = 10 + + @usableFromInline + var keepAliveDuration: Duration? + + @usableFromInline + var idleTimeoutDuration: Duration = .seconds(30) +} + +@usableFromInline +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +struct PoolStateMachine< + Connection: PooledConnection, + ConnectionIDGenerator: ConnectionIDGeneratorProtocol, + ConnectionID: Hashable & Sendable, + Request: ConnectionRequestProtocol, + RequestID, + TimerCancellationToken +> where Connection.ID == ConnectionID, ConnectionIDGenerator.ID == ConnectionID, RequestID == Request.ID { + + @usableFromInline + struct Timer: Hashable, Sendable { + @usableFromInline + enum Usecase: Sendable { + case backoff + case idleTimeout + case keepAlive + } + + @usableFromInline + var connectionID: ConnectionID + + @usableFromInline + var timerID: Int + + @usableFromInline + var duration: Duration + + @usableFromInline + var usecase: Usecase + + @inlinable + init(connectionID: ConnectionID, timerID: Int, duration: Duration, usecase: Usecase) { + self.connectionID = connectionID + self.timerID = timerID + self.duration = duration + self.usecase = usecase + } + } + + +} diff --git a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift new file mode 100644 index 00000000..5845267f --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift @@ -0,0 +1,27 @@ +@testable import _ConnectionPoolModule +import XCTest + +final class ConnectionRequestTests: XCTestCase { + + func testHappyPath() async throws { + let mockConnection = MockConnection(id: 1) + let connection = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let request = ConnectionRequest(id: 42, continuation: continuation) + XCTAssertEqual(request.id, 42) + continuation.resume(with: .success(mockConnection)) + } + + XCTAssert(connection === mockConnection) + } + + func testSadPath() async throws { + do { + _ = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + continuation.resume(with: .failure(ConnectionPoolError.requestCancelled)) + } + XCTFail("This point should not be reached") + } catch { + XCTAssertEqual(error as? ConnectionPoolError, .requestCancelled) + } + } +} diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift new file mode 100644 index 00000000..6aaa9c91 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift @@ -0,0 +1,28 @@ +import _ConnectionPoolModule + +final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { + typealias Connection = MockConnection + + struct ID: Hashable { + var objectID: ObjectIdentifier + + init(_ request: MockRequest) { + self.objectID = ObjectIdentifier(request) + } + } + + var id: ID { ID(self) } + + + static func ==(lhs: MockRequest, rhs: MockRequest) -> Bool { + lhs.id == rhs.id + } + + func hash(into hasher: inout Hasher) { + hasher.combine(self.id) + } + + func complete(with: Result) { + + } +} diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift new file mode 100644 index 00000000..20434450 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift @@ -0,0 +1,16 @@ +@testable import _ConnectionPoolModule + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +struct MockTimerCancellationToken: Hashable, Sendable { + var connectionID: MockConnection.ID + var timerID: Int + var duration: Duration + var usecase: TestPoolStateMachine.Timer.Usecase + + init(_ timer: TestPoolStateMachine.Timer) { + self.connectionID = timer.connectionID + self.timerID = timer.timerID + self.duration = timer.duration + self.usecase = timer.usecase + } +} diff --git a/Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift b/Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift index 8098438f..a086341e 100644 --- a/Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift +++ b/Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift @@ -35,7 +35,7 @@ final class OneElementFastSequenceTests: XCTestCase { } XCTAssertEqual(array.capacity, 8) - var oneElemSequence = OneElementFastSequence(1) + var oneElemSequence = OneElementFastSequence(element: 1) oneElemSequence.reserveCapacity(8) oneElemSequence.append(2) guard case .n(let array) = oneElemSequence.base else { diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift new file mode 100644 index 00000000..0231da51 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift @@ -0,0 +1,147 @@ +@testable import _ConnectionPoolModule +import XCTest + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class PoolStateMachine_RequestQueueTests: XCTestCase { + + typealias TestQueue = TestPoolStateMachine.RequestQueue + + func testHappyPath() { + var queue = TestQueue() + XCTAssert(queue.isEmpty) + + let request1 = MockRequest() + queue.queue(request1) + XCTAssertEqual(queue.count, 1) + XCTAssertFalse(queue.isEmpty) + let popResult = queue.pop(max: 3) + XCTAssert(popResult.elementsEqual([request1])) + XCTAssert(queue.isEmpty) + XCTAssertEqual(queue.count, 0) + } + + func testEnqueueAndPopMultipleRequests() { + var queue = TestQueue() + XCTAssert(queue.isEmpty) + + var request1 = MockRequest() + queue.queue(request1) + var request2 = MockRequest() + queue.queue(request2) + var request3 = MockRequest() + queue.queue(request3) + + do { + XCTAssertEqual(queue.count, 3) + XCTAssertFalse(queue.isEmpty) + let popResult = queue.pop(max: 3) + XCTAssert(popResult.elementsEqual([request1, request2, request3])) + XCTAssert(queue.isEmpty) + XCTAssertEqual(queue.count, 0) + } + XCTAssert(isKnownUniquelyReferenced(&request1)) + XCTAssert(isKnownUniquelyReferenced(&request2)) + XCTAssert(isKnownUniquelyReferenced(&request3)) + } + + func testEnqueueAndPopOnlyOne() { + var queue = TestQueue() + XCTAssert(queue.isEmpty) + + var request1 = MockRequest() + queue.queue(request1) + var request2 = MockRequest() + queue.queue(request2) + var request3 = MockRequest() + queue.queue(request3) + + do { + XCTAssertEqual(queue.count, 3) + XCTAssertFalse(queue.isEmpty) + let popResult = queue.pop(max: 1) + XCTAssert(popResult.elementsEqual([request1])) + XCTAssertFalse(queue.isEmpty) + XCTAssertEqual(queue.count, 2) + + let removeAllResult = queue.removeAll() + XCTAssert(Set(removeAllResult) == [request2, request3]) + } + XCTAssert(isKnownUniquelyReferenced(&request1)) + XCTAssert(isKnownUniquelyReferenced(&request2)) + XCTAssert(isKnownUniquelyReferenced(&request3)) + } + + func testCancellation() { + var queue = TestQueue() + XCTAssert(queue.isEmpty) + + var request1 = MockRequest() + queue.queue(request1) + var request2 = MockRequest() + queue.queue(request2) + var request3 = MockRequest() + queue.queue(request3) + + do { + XCTAssertEqual(queue.count, 3) + let returnedRequest2 = queue.remove(request2.id) + XCTAssert(returnedRequest2 === request2) + XCTAssertEqual(queue.count, 2) + XCTAssertFalse(queue.isEmpty) + } + + // still retained by the deque inside the queue + XCTAssertEqual(queue.requests.count, 2) + XCTAssertEqual(queue.queue.count, 3) + + do { + XCTAssertEqual(queue.count, 2) + XCTAssertFalse(queue.isEmpty) + let popResult = queue.pop(max: 3) + XCTAssert(popResult.elementsEqual([request1, request3])) + XCTAssert(queue.isEmpty) + XCTAssertEqual(queue.count, 0) + } + + XCTAssert(isKnownUniquelyReferenced(&request1)) + XCTAssert(isKnownUniquelyReferenced(&request2)) + XCTAssert(isKnownUniquelyReferenced(&request3)) + } + + func testRemoveAllAfterCancellation() { + var queue = TestQueue() + XCTAssert(queue.isEmpty) + + var request1 = MockRequest() + queue.queue(request1) + var request2 = MockRequest() + queue.queue(request2) + var request3 = MockRequest() + queue.queue(request3) + + do { + XCTAssertEqual(queue.count, 3) + let returnedRequest2 = queue.remove(request2.id) + XCTAssert(returnedRequest2 === request2) + XCTAssertEqual(queue.count, 2) + XCTAssertFalse(queue.isEmpty) + } + + // still retained by the deque inside the queue + XCTAssertEqual(queue.requests.count, 2) + XCTAssertEqual(queue.queue.count, 3) + + do { + XCTAssertEqual(queue.count, 2) + XCTAssertFalse(queue.isEmpty) + let removeAllResult = queue.removeAll() + XCTAssert(Set(removeAllResult) == [request1, request3]) + XCTAssert(queue.isEmpty) + XCTAssertEqual(queue.count, 0) + } + + XCTAssert(isKnownUniquelyReferenced(&request1)) + XCTAssert(isKnownUniquelyReferenced(&request2)) + XCTAssert(isKnownUniquelyReferenced(&request3)) + } +} diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift new file mode 100644 index 00000000..ee8cfdc6 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -0,0 +1,14 @@ +import NIOCore +import NIOEmbedded +import XCTest +@testable import _ConnectionPoolModule + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +typealias TestPoolStateMachine = PoolStateMachine< + MockConnection, + ConnectionIDGenerator, + MockConnection.ID, + MockRequest, + MockRequest.ID, + MockTimerCancellationToken +> From 20a8c340ed4984b6c85aabd27a38fa5b2d780ee0 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 18 Oct 2023 22:37:54 +0200 Subject: [PATCH 201/292] Add `PoolStateMachine.ConnectionState` (#425) --- .../ConnectionPoolModule/Max2Sequence.swift | 10 + .../PoolStateMachine+ConnectionState.swift | 584 ++++++++++++++++++ .../PoolStateMachine.swift | 53 +- .../Mocks/MockTimerCancellationToken.swift | 18 +- ...oolStateMachine+ConnectionStateTests.swift | 264 ++++++++ 5 files changed, 904 insertions(+), 25 deletions(-) create mode 100644 Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift create mode 100644 Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift diff --git a/Sources/ConnectionPoolModule/Max2Sequence.swift b/Sources/ConnectionPoolModule/Max2Sequence.swift index 6c330067..0feccd68 100644 --- a/Sources/ConnectionPoolModule/Max2Sequence.swift +++ b/Sources/ConnectionPoolModule/Max2Sequence.swift @@ -90,6 +90,16 @@ struct Max2Sequence: Sequence { } } +extension Max2Sequence: ExpressibleByArrayLiteral { + @inlinable + init(arrayLiteral elements: Element...) { + precondition(elements.count <= 2) + var iterator = elements.makeIterator() + self.first = iterator.next() + self.second = iterator.next() + } +} + extension Max2Sequence: Equatable where Element: Equatable {} extension Max2Sequence: Hashable where Element: Hashable {} extension Max2Sequence: Sendable where Element: Sendable {} diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift new file mode 100644 index 00000000..51ab5323 --- /dev/null +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift @@ -0,0 +1,584 @@ +import Atomics + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine { + + @usableFromInline + struct KeepAliveAction { + @usableFromInline + var connection: Connection + @usableFromInline + var keepAliveTimerCancellationContinuation: TimerCancellationToken? + + @inlinable + init(connection: Connection, keepAliveTimerCancellationContinuation: TimerCancellationToken? = nil) { + self.connection = connection + self.keepAliveTimerCancellationContinuation = keepAliveTimerCancellationContinuation + } + } + + @usableFromInline + struct ConnectionTimer: Hashable, Sendable { + @usableFromInline + enum Usecase: Hashable, Sendable { + case backoff + case keepAlive + case idleTimeout + } + + @usableFromInline + var timerID: Int + + @usableFromInline + var connectionID: Connection.ID + + @usableFromInline + var usecase: Usecase + + @inlinable + init(timerID: Int, connectionID: Connection.ID, usecase: Usecase) { + self.timerID = timerID + self.connectionID = connectionID + self.usecase = usecase + } + } + + @usableFromInline + /// An connection state machine about the pool's view on the connection. + struct ConnectionState { + @usableFromInline + enum State { + @usableFromInline + enum KeepAlive { + case notScheduled + case scheduled(Timer) + case running(_ consumingStream: Bool) + + @inlinable + var usedStreams: UInt16 { + switch self { + case .notScheduled, .scheduled, .running(false): + return 0 + case .running(true): + return 1 + } + } + + @inlinable + var isRunning: Bool { + switch self { + case .running: + return true + case .notScheduled, .scheduled: + return false + } + } + + @inlinable + mutating func cancelTimerIfScheduled() -> TimerCancellationToken? { + switch self { + case .scheduled(let timer): + self = .notScheduled + return timer.cancellationContinuation + case .running, .notScheduled: + return nil + } + } + } + + @usableFromInline + struct Timer { + @usableFromInline + let timerID: Int + + @usableFromInline + private(set) var cancellationContinuation: TimerCancellationToken? + + @inlinable + init(id: Int) { + self.timerID = id + self.cancellationContinuation = nil + } + + @inlinable + mutating func registerCancellationContinuation(_ continuation: TimerCancellationToken) { + precondition(self.cancellationContinuation == nil) + self.cancellationContinuation = continuation + } + } + + /// The pool is creating a connection. Valid transitions are to: `.backingOff`, `.idle`, and `.closed` + case starting + /// The pool is waiting to retry establishing a connection. Valid transitions are to: `.closed`. + /// This means, the connection can be removed from the connections without cancelling external + /// state. The connection state can then be replaced by a new one. + case backingOff(Timer) + /// The connection is `idle` and ready to execute a new query. Valid transitions to: `.pingpong`, `.leased`, + /// `.closing` and `.closed` + case idle(Connection, maxStreams: UInt16, keepAlive: KeepAlive, idleTimer: Timer?) + /// The connection is leased and executing a query. Valid transitions to: `.idle` and `.closed` + case leased(Connection, usedStreams: UInt16, maxStreams: UInt16, keepAlive: KeepAlive) + /// The connection is closing. Valid transitions to: `.closed` + case closing(Connection) + /// The connection is closed. Final state. + case closed + } + + @usableFromInline + let id: Connection.ID + + @usableFromInline + private(set) var state: State = .starting + + @usableFromInline + private(set) var nextTimerID: Int = 0 + + @inlinable + init(id: Connection.ID) { + self.id = id + } + + @inlinable + var isIdle: Bool { + switch self.state { + case .idle(_, _, .notScheduled, _), .idle(_, _, .scheduled, _): + return true + case .idle(_, _, .running, _): + return false + case .backingOff, .starting, .closed, .closing, .leased: + return false + } + } + + @inlinable + var isAvailable: Bool { + switch self.state { + case .idle(_, let maxStreams, .running(true), _): + return maxStreams > 1 + case .idle(_, let maxStreams, let keepAlive, _): + return keepAlive.usedStreams < maxStreams + case .leased(_, let usedStreams, let maxStreams, let keepAlive): + return usedStreams + keepAlive.usedStreams < maxStreams + case .backingOff, .starting, .closed, .closing: + return false + } + } + + @usableFromInline + var isLeased: Bool { + switch self.state { + case .leased: + return true + case .backingOff, .starting, .closed, .closing, .idle: + return false + } + } + + @usableFromInline + var isIdleOrRunningKeepAlive: Bool { + switch self.state { + case .idle: + return true + case .backingOff, .starting, .closed, .closing, .leased: + return false + } + } + + @usableFromInline + var isConnected: Bool { + switch self.state { + case .idle, .leased: + return true + case .backingOff, .starting, .closed, .closing: + return false + } + } + + @inlinable + mutating func connected(_ connection: Connection, maxStreams: UInt16) -> ConnectionAvailableInfo { + switch self.state { + case .starting: + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: .notScheduled, idleTimer: nil) + return .idle(availableStreams: maxStreams, newIdle: true) + case .backingOff, .idle, .leased, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + @inlinable + mutating func parkConnection(scheduleKeepAliveTimer: Bool, scheduleIdleTimeoutTimer: Bool) -> Max2Sequence { + var keepAliveTimer: ConnectionTimer? + var keepAliveTimerState: State.Timer? + var idleTimer: ConnectionTimer? + var idleTimerState: State.Timer? + + switch self.state { + case .backingOff, .starting, .leased, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + + case .idle(let connection, let maxStreams, .notScheduled, .none): + let keepAlive: State.KeepAlive + if scheduleKeepAliveTimer { + keepAliveTimerState = self.nextTimer() + keepAliveTimer = ConnectionTimer(timerID: keepAliveTimerState!.timerID, connectionID: self.id, usecase: .keepAlive) + keepAlive = .scheduled(keepAliveTimerState!) + } else { + keepAlive = .notScheduled + } + if scheduleIdleTimeoutTimer { + idleTimerState = self.nextTimer() + idleTimer = ConnectionTimer(timerID: idleTimerState!.timerID, connectionID: self.id, usecase: .idleTimeout) + } + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: keepAlive, idleTimer: idleTimerState) + return Max2Sequence(keepAliveTimer, idleTimer) + + case .idle(_, _, .scheduled, .some): + precondition(!scheduleKeepAliveTimer) + precondition(!scheduleIdleTimeoutTimer) + return Max2Sequence() + + case .idle(let connection, let maxStreams, .notScheduled, let idleTimer): + precondition(!scheduleIdleTimeoutTimer) + let keepAlive: State.KeepAlive + if scheduleKeepAliveTimer { + keepAliveTimerState = self.nextTimer() + keepAliveTimer = ConnectionTimer(timerID: keepAliveTimerState!.timerID, connectionID: self.id, usecase: .keepAlive) + keepAlive = .scheduled(keepAliveTimerState!) + } else { + keepAlive = .notScheduled + } + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: keepAlive, idleTimer: idleTimer) + return Max2Sequence(keepAliveTimer) + + case .idle(let connection, let maxStreams, .scheduled(let keepAliveTimer), .none): + precondition(!scheduleKeepAliveTimer) + + if scheduleIdleTimeoutTimer { + idleTimerState = self.nextTimer() + idleTimer = ConnectionTimer(timerID: idleTimerState!.timerID, connectionID: self.id, usecase: .keepAlive) + } + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: .scheduled(keepAliveTimer), idleTimer: idleTimerState) + return Max2Sequence(idleTimer, nil) + + case .idle(let connection, let maxStreams, keepAlive: .running(let usingStream), idleTimer: .none): + if scheduleIdleTimeoutTimer { + idleTimerState = self.nextTimer() + idleTimer = ConnectionTimer(timerID: idleTimerState!.timerID, connectionID: self.id, usecase: .keepAlive) + } + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: .running(usingStream), idleTimer: idleTimerState) + return Max2Sequence(keepAliveTimer, idleTimer) + + case .idle(_, _, keepAlive: .running(_), idleTimer: .some): + precondition(!scheduleKeepAliveTimer) + precondition(!scheduleIdleTimeoutTimer) + return Max2Sequence() + } + } + + @inlinable + mutating func nextTimer() -> State.Timer { + defer { self.nextTimerID += 1 } + return State.Timer(id: self.nextTimerID) + } + + /// The connection failed to start + @inlinable + mutating func failedToConnect() -> ConnectionTimer { + switch self.state { + case .starting: + let backoffTimerState = self.nextTimer() + self.state = .backingOff(backoffTimerState) + return ConnectionTimer(timerID: backoffTimerState.timerID, connectionID: self.id, usecase: .backoff) + + case .backingOff, .idle, .leased, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + /// Moves a connection, that has previously ``failedToConnect()`` back into the connecting state. + /// + /// - Returns: A ``TimerCancellationToken`` that was previously registered with the state machine + /// for the ``ConnectionTimer`` returned in ``failedToConnect()``. If no token was registered + /// nil is returned. + @inlinable + mutating func retryConnect() -> TimerCancellationToken? { + switch self.state { + case .backingOff(let timer): + self.state = .starting + return timer.cancellationContinuation + case .starting, .idle, .leased, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + @usableFromInline + struct LeaseAction { + @usableFromInline + var connection: Connection + @usableFromInline + var timersToCancel: Max2Sequence + @usableFromInline + var wasIdle: Bool + + @inlinable + init(connection: Connection, timersToCancel: Max2Sequence, wasIdle: Bool) { + self.connection = connection + self.timersToCancel = timersToCancel + self.wasIdle = wasIdle + } + } + + @inlinable + mutating func lease(streams newLeasedStreams: UInt16 = 1) -> LeaseAction { + switch self.state { + case .idle(let connection, let maxStreams, var keepAlive, let idleTimer): + var cancel = Max2Sequence() + if let token = idleTimer?.cancellationContinuation { + cancel.append(token) + } + if let token = keepAlive.cancelTimerIfScheduled() { + cancel.append(token) + } + precondition(maxStreams >= newLeasedStreams + keepAlive.usedStreams, "Invalid state: \(self.state)") + self.state = .leased(connection, usedStreams: newLeasedStreams, maxStreams: maxStreams, keepAlive: keepAlive) + return LeaseAction(connection: connection, timersToCancel: cancel, wasIdle: true) + + case .leased(let connection, let usedStreams, let maxStreams, let keepAlive): + precondition(maxStreams >= usedStreams + newLeasedStreams + keepAlive.usedStreams, "Invalid state: \(self.state)") + self.state = .leased(connection, usedStreams: usedStreams + newLeasedStreams, maxStreams: maxStreams, keepAlive: keepAlive) + return LeaseAction(connection: connection, timersToCancel: .init(), wasIdle: false) + + case .backingOff, .starting, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + @inlinable + mutating func release(streams returnedStreams: UInt16) -> ConnectionAvailableInfo { + switch self.state { + case .leased(let connection, let usedStreams, let maxStreams, let keepAlive): + precondition(usedStreams >= returnedStreams) + let newUsedStreams = usedStreams - returnedStreams + let availableStreams = maxStreams - (newUsedStreams + keepAlive.usedStreams) + if newUsedStreams == 0 { + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: keepAlive, idleTimer: nil) + return .idle(availableStreams: availableStreams, newIdle: true) + } else { + self.state = .leased(connection, usedStreams: newUsedStreams, maxStreams: maxStreams, keepAlive: keepAlive) + return .leased(availableStreams: availableStreams) + } + case .backingOff, .starting, .idle, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + @inlinable + mutating func runKeepAliveIfIdle(reducesAvailableStreams: Bool) -> KeepAliveAction? { + switch self.state { + case .idle(let connection, let maxStreams, .scheduled(let timer), let idleTimer): + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: .running(reducesAvailableStreams), idleTimer: idleTimer) + return KeepAliveAction( + connection: connection, + keepAliveTimerCancellationContinuation: timer.cancellationContinuation + ) + + case .leased, .closed, .closing: + return nil + + case .backingOff, .starting, .idle(_, _, .running, _), .idle(_, _, .notScheduled, _): + preconditionFailure("Invalid state: \(self.state)") + } + } + + @inlinable + mutating func keepAliveSucceeded() -> ConnectionAvailableInfo? { + switch self.state { + case .idle(let connection, let maxStreams, .running, let idleTimer): + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: .notScheduled, idleTimer: idleTimer) + return .idle(availableStreams: maxStreams, newIdle: false) + + case .leased(let connection, let usedStreams, let maxStreams, .running): + self.state = .leased(connection, usedStreams: usedStreams, maxStreams: maxStreams, keepAlive: .notScheduled) + return .leased(availableStreams: maxStreams - usedStreams) + + case .closed, .closing: + return nil + + case .backingOff, .starting, + .leased(_, _, _, .notScheduled), + .leased(_, _, _, .scheduled), + .idle(_, _, .notScheduled, _), + .idle(_, _, .scheduled, _): + preconditionFailure("Invalid state: \(self.state)") + } + } + + @inlinable + mutating func timerScheduled( + _ timer: ConnectionTimer, + cancelContinuation: TimerCancellationToken + ) -> TimerCancellationToken? { + switch timer.usecase { + case .backoff: + switch self.state { + case .backingOff(var timerState): + if timerState.timerID == timer.timerID { + timerState.registerCancellationContinuation(cancelContinuation) + self.state = .backingOff(timerState) + return nil + } else { + return cancelContinuation + } + + case .starting, .idle, .leased, .closing, .closed: + return cancelContinuation + } + + case .idleTimeout: + switch self.state { + case .idle(let connection, let maxStreams, let keepAlive, let idleTimerState): + if var idleTimerState = idleTimerState, idleTimerState.timerID == timer.timerID { + idleTimerState.registerCancellationContinuation(cancelContinuation) + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: keepAlive, idleTimer: idleTimerState) + return nil + } else { + return cancelContinuation + } + + case .starting, .backingOff, .leased, .closing, .closed: + return cancelContinuation + } + + case .keepAlive: + switch self.state { + case .idle(let connection, let maxStreams, .scheduled(var keepAliveTimerState), let idleTimerState): + if keepAliveTimerState.timerID == timer.timerID { + keepAliveTimerState.registerCancellationContinuation(cancelContinuation) + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: .scheduled(keepAliveTimerState), idleTimer: idleTimerState) + return nil + } else { + return cancelContinuation + } + + case .starting, .backingOff, .leased, .closing, .closed, + .idle(_, _, .running, _), + .idle(_, _, .notScheduled, _): + return cancelContinuation + } + } + } + + @usableFromInline + struct CloseAction { + @usableFromInline + var connection: Connection + @usableFromInline + var cancelTimers: Max2Sequence + @usableFromInline + var maxStreams: UInt16 + + @inlinable + init(connection: Connection, cancelTimers: Max2Sequence, maxStreams: UInt16) { + self.connection = connection + self.cancelTimers = cancelTimers + self.maxStreams = maxStreams + } + } + + @inlinable + mutating func close() -> CloseAction { + switch self.state { + case .idle(let connection, let maxStreams, var keepAlive, let idleTimerState): + self.state = .closing(connection) + return CloseAction( + connection: connection, + cancelTimers: Max2Sequence( + keepAlive.cancelTimerIfScheduled(), + idleTimerState?.cancellationContinuation + ), + maxStreams: maxStreams + ) + + case .backingOff, .starting, .leased, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + @inlinable + mutating func closeIfIdle() -> CloseAction? { + switch self.state { + case .idle: + return self.close() + case .leased, .closed: + return nil + case .backingOff, .starting, .closing: + preconditionFailure("Invalid state: \(self.state)") + } + } + + @usableFromInline + struct ShutdownAction { + @usableFromInline + var connection: Connection? + @usableFromInline + var timersToCancel: Max2Sequence + @usableFromInline + var maxStreams: UInt16 + @usableFromInline + var usedStreams: UInt16 + + @inlinable + init( + connection: Connection? = nil, + timersToCancel: Max2Sequence = .init(), + maxStreams: UInt16 = 0, + usedStreams: UInt16 = 0 + ) { + self.connection = connection + self.timersToCancel = timersToCancel + self.maxStreams = maxStreams + self.usedStreams = usedStreams + } + } + } + + @usableFromInline + enum ConnectionAvailableInfo: Equatable { + case leased(availableStreams: UInt16) + case idle(availableStreams: UInt16, newIdle: Bool) + + @usableFromInline + var availableStreams: UInt16 { + switch self { + case .leased(let availableStreams): + return availableStreams + case .idle(let availableStreams, newIdle: _): + return availableStreams + } + } + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.KeepAliveAction: Equatable where TimerCancellationToken: Equatable { + @inlinable + static func == (lhs: Self, rhs: Self) -> Bool { + lhs.connection === rhs.connection && lhs.keepAliveTimerCancellationContinuation == rhs.keepAliveTimerCancellationContinuation + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.ConnectionState.LeaseAction: Equatable where TimerCancellationToken: Equatable { + @inlinable + static func == (lhs: Self, rhs: Self) -> Bool { + lhs.wasIdle == rhs.wasIdle && lhs.connection === rhs.connection && lhs.timersToCancel == rhs.timersToCancel + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.ConnectionState.CloseAction: Equatable where TimerCancellationToken: Equatable { + @inlinable + static func == (lhs: Self, rhs: Self) -> Bool { + lhs.cancelTimers == rhs.cancelTimers && lhs.connection === rhs.connection && lhs.maxStreams == rhs.maxStreams + } +} diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index a3962790..dc18784f 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -39,36 +39,55 @@ struct PoolStateMachine< RequestID, TimerCancellationToken > where Connection.ID == ConnectionID, ConnectionIDGenerator.ID == ConnectionID, RequestID == Request.ID { + + @usableFromInline + struct ConnectionRequest: Equatable { + @usableFromInline var connectionID: ConnectionID + + @inlinable + init(connectionID: ConnectionID) { + self.connectionID = connectionID + } + } @usableFromInline - struct Timer: Hashable, Sendable { + enum ConnectionAction { @usableFromInline - enum Usecase: Sendable { - case backoff - case idleTimeout - case keepAlive + struct Shutdown { + @usableFromInline + var connections: [Connection] + @usableFromInline + var timersToCancel: [TimerCancellationToken] + + @inlinable + init() { + self.connections = [] + self.timersToCancel = [] + } } - @usableFromInline - var connectionID: ConnectionID + case scheduleTimers(Max2Sequence) + case makeConnection(ConnectionRequest, TimerCancellationToken?) + case runKeepAlive(Connection, TimerCancellationToken?) + case cancelTimers(Max2Sequence) + case closeConnection(Connection) + case shutdown(Shutdown) - @usableFromInline - var timerID: Int + case none + } + @usableFromInline + struct Timer: Hashable, Sendable { @usableFromInline - var duration: Duration + var underlying: ConnectionTimer @usableFromInline - var usecase: Usecase + var duration: Duration @inlinable - init(connectionID: ConnectionID, timerID: Int, duration: Duration, usecase: Usecase) { - self.connectionID = connectionID - self.timerID = timerID + init(_ connectionTimer: ConnectionTimer, duration: Duration) { + self.underlying = connectionTimer self.duration = duration - self.usecase = usecase } } - - } diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift index 20434450..27035ee9 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockTimerCancellationToken.swift @@ -2,15 +2,17 @@ @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) struct MockTimerCancellationToken: Hashable, Sendable { - var connectionID: MockConnection.ID - var timerID: Int - var duration: Duration - var usecase: TestPoolStateMachine.Timer.Usecase + enum Backing: Hashable, Sendable { + case timer(TestPoolStateMachine.Timer) + case connectionTimer(TestPoolStateMachine.ConnectionTimer) + } + var backing: Backing init(_ timer: TestPoolStateMachine.Timer) { - self.connectionID = timer.connectionID - self.timerID = timer.timerID - self.duration = timer.duration - self.usecase = timer.usecase + self.backing = .timer(timer) + } + + init(_ timer: TestPoolStateMachine.ConnectionTimer) { + self.backing = .connectionTimer(timer) } } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift new file mode 100644 index 00000000..b1622d0d --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift @@ -0,0 +1,264 @@ +@testable import _ConnectionPoolModule +import XCTest + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class PoolStateMachine_ConnectionStateTests: XCTestCase { + + typealias TestConnectionState = TestPoolStateMachine.ConnectionState + + func testStartupLeaseReleaseParkLease() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + XCTAssertEqual(state.id, connectionID) + XCTAssertEqual(state.isIdleOrRunningKeepAlive, false) + XCTAssertEqual(state.isAvailable, false) + XCTAssertEqual(state.isConnected, false) + XCTAssertEqual(state.isLeased, false) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(state.isIdleOrRunningKeepAlive, true) + XCTAssertEqual(state.isAvailable, true) + XCTAssertEqual(state.isConnected, true) + XCTAssertEqual(state.isLeased, false) + XCTAssertEqual(state.lease(streams: 1), .init(connection: connection, timersToCancel: .init(), wasIdle: true)) + + XCTAssertEqual(state.isIdleOrRunningKeepAlive, false) + XCTAssertEqual(state.isAvailable, false) + XCTAssertEqual(state.isConnected, true) + XCTAssertEqual(state.isLeased, true) + + XCTAssertEqual(state.release(streams: 1), .idle(availableStreams: 1, newIdle: true)) + let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) + XCTAssert( + parkResult.elementsEqual([ + .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive), + .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout) + ]) + ) + + guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { + return XCTFail("Expected to get two timers") + } + + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + let idleTimerCancellationToken = MockTimerCancellationToken(idleTimer) + + XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + XCTAssertNil(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken)) + + let expectLeaseAction = TestConnectionState.LeaseAction( + connection: connection, + timersToCancel: [idleTimerCancellationToken, keepAliveTimerCancellationToken], + wasIdle: true + ) + XCTAssertEqual(state.lease(streams: 1), expectLeaseAction) + } + + func testStartupParkLeaseBeforeTimersRegistered() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) + XCTAssertEqual( + parkResult, + [ + .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive), + .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout) + ] + ) + + guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { + return XCTFail("Expected to get two timers") + } + + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + let idleTimerCancellationToken = MockTimerCancellationToken(idleTimer) + XCTAssertEqual(state.lease(streams: 1), .init(connection: connection, timersToCancel: .init(), wasIdle: true)) + + XCTAssertEqual(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken), keepAliveTimerCancellationToken) + XCTAssertEqual(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken), idleTimerCancellationToken) + } + + func testStartupParkLeasePark() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) + XCTAssert( + parkResult.elementsEqual([ + .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive), + .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout) + ]) + ) + + guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { + return XCTFail("Expected to get two timers") + } + + let initialKeepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + let initialIdleTimerCancellationToken = MockTimerCancellationToken(idleTimer) + XCTAssertEqual(state.lease(streams: 1), .init(connection: connection, timersToCancel: .init(), wasIdle: true)) + + XCTAssertEqual(state.release(streams: 1), .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual( + state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true), + [ + .init(timerID: 2, connectionID: connectionID, usecase: .keepAlive), + .init(timerID: 3, connectionID: connectionID, usecase: .idleTimeout) + ] + ) + + XCTAssertEqual(state.timerScheduled(keepAliveTimer, cancelContinuation: initialKeepAliveTimerCancellationToken), initialKeepAliveTimerCancellationToken) + XCTAssertEqual(state.timerScheduled(idleTimer, cancelContinuation: initialIdleTimerCancellationToken), initialIdleTimerCancellationToken) + } + + func testStartupFailed() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let firstBackoffTimer = state.failedToConnect() + let firstBackoffTimerCancellationToken = MockTimerCancellationToken(firstBackoffTimer) + XCTAssertNil(state.timerScheduled(firstBackoffTimer, cancelContinuation: firstBackoffTimerCancellationToken)) + XCTAssertEqual(state.retryConnect(), firstBackoffTimerCancellationToken) + + let secondBackoffTimer = state.failedToConnect() + let secondBackoffTimerCancellationToken = MockTimerCancellationToken(secondBackoffTimer) + XCTAssertNil(state.retryConnect()) + XCTAssertEqual( + state.timerScheduled(secondBackoffTimer, cancelContinuation: secondBackoffTimerCancellationToken), + secondBackoffTimerCancellationToken + ) + + let thirdBackoffTimer = state.failedToConnect() + let thirdBackoffTimerCancellationToken = MockTimerCancellationToken(thirdBackoffTimer) + XCTAssertNil(state.retryConnect()) + let forthBackoffTimer = state.failedToConnect() + let forthBackoffTimerCancellationToken = MockTimerCancellationToken(forthBackoffTimer) + XCTAssertEqual( + state.timerScheduled(thirdBackoffTimer, cancelContinuation: thirdBackoffTimerCancellationToken), + thirdBackoffTimerCancellationToken + ) + XCTAssertNil( + state.timerScheduled(forthBackoffTimer, cancelContinuation: forthBackoffTimerCancellationToken) + ) + XCTAssertEqual(state.retryConnect(), forthBackoffTimerCancellationToken) + + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + } + + func testLeaseMultipleStreams() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 100), .idle(availableStreams: 100, newIdle: true)) + let timers = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: false) + guard let keepAliveTimer = timers.first else { return XCTFail("Expected to get a keepAliveTimer") } + + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + + XCTAssertEqual( + state.lease(streams: 30), + TestConnectionState.LeaseAction(connection: connection, timersToCancel: [keepAliveTimerCancellationToken], wasIdle: true) + ) + + XCTAssertEqual(state.release(streams: 10), .leased(availableStreams: 80)) + + XCTAssertEqual( + state.lease(streams: 40), + TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: false) + ) + + XCTAssertEqual( + state.lease(streams: 40), + TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: false) + ) + + XCTAssertEqual(state.release(streams: 1), .leased(availableStreams: 1)) + XCTAssertEqual(state.release(streams: 98), .leased(availableStreams: 99)) + XCTAssertEqual(state.release(streams: 1), .idle(availableStreams: 100, newIdle: true)) + } + + func testRunningKeepAliveReducesAvailableStreams() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 100), .idle(availableStreams: 100, newIdle: true)) + let timers = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: false) + guard let keepAliveTimer = timers.first else { return XCTFail("Expected to get a keepAliveTimer") } + + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + + XCTAssertEqual( + state.runKeepAliveIfIdle(reducesAvailableStreams: true), + .init(connection: connection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken) + ) + + XCTAssertEqual( + state.lease(streams: 30), + TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: true) + ) + + XCTAssertEqual(state.release(streams: 10), .leased(availableStreams: 79)) + XCTAssertEqual(state.isAvailable, true) + XCTAssertEqual( + state.lease(streams: 79), + TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: false) + ) + XCTAssertEqual(state.isAvailable, false) + XCTAssertEqual(state.keepAliveSucceeded(), .leased(availableStreams: 1)) + XCTAssertEqual(state.isAvailable, true) + } + + func testRunningKeepAliveDoesNotReduceAvailableStreams() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 100), .idle(availableStreams: 100, newIdle: true)) + let timers = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: false) + guard let keepAliveTimer = timers.first else { return XCTFail("Expected to get a keepAliveTimer") } + + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + + XCTAssertEqual( + state.runKeepAliveIfIdle(reducesAvailableStreams: false), + .init(connection: connection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken) + ) + + XCTAssertEqual( + state.lease(streams: 30), + TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: true) + ) + + XCTAssertEqual(state.release(streams: 10), .leased(availableStreams: 80)) + XCTAssertEqual(state.keepAliveSucceeded(), .leased(availableStreams: 80)) + } + + func testRunKeepAliveRacesAgainstIdleClose() { + let connectionID = 1 + var state = TestConnectionState(id: connectionID) + let connection = MockConnection(id: connectionID) + XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) + guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { + return XCTFail("Expected to get two timers") + } + + XCTAssertEqual(keepAliveTimer, .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive)) + XCTAssertEqual(idleTimer, .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout)) + + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + let idleTimerCancellationToken = MockTimerCancellationToken(idleTimer) + + XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + XCTAssertNil(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken)) + + XCTAssertEqual(state.closeIfIdle(), .init(connection: connection, cancelTimers: [keepAliveTimerCancellationToken, idleTimerCancellationToken], maxStreams: 1)) + XCTAssertEqual(state.runKeepAliveIfIdle(reducesAvailableStreams: true), .none) + + } +} From 17d3c80e7739c781254c1883bd9e8fd6c113b1c1 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 23 Oct 2023 11:20:32 +0200 Subject: [PATCH 202/292] Add `PoolStateMachine.ConnectionGroup` (#425) (#426) --- .../ConnectionPoolModule/Max2Sequence.swift | 3 +- .../PoolStateMachine+ConnectionGroup.swift | 640 ++++++++++++++++++ .../PoolStateMachine+ConnectionState.swift | 218 ++++-- .../PoolStateMachine.swift | 2 +- ...oolStateMachine+ConnectionGroupTests.swift | 294 ++++++++ ...oolStateMachine+ConnectionStateTests.swift | 8 +- 6 files changed, 1113 insertions(+), 52 deletions(-) create mode 100644 Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift create mode 100644 Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift diff --git a/Sources/ConnectionPoolModule/Max2Sequence.swift b/Sources/ConnectionPoolModule/Max2Sequence.swift index 0feccd68..9b7d972b 100644 --- a/Sources/ConnectionPoolModule/Max2Sequence.swift +++ b/Sources/ConnectionPoolModule/Max2Sequence.swift @@ -95,8 +95,7 @@ extension Max2Sequence: ExpressibleByArrayLiteral { init(arrayLiteral elements: Element...) { precondition(elements.count <= 2) var iterator = elements.makeIterator() - self.first = iterator.next() - self.second = iterator.next() + self.init(iterator.next(), iterator.next()) } } diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift new file mode 100644 index 00000000..8ec99c7d --- /dev/null +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -0,0 +1,640 @@ +import Atomics + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine { + + @usableFromInline + struct LeaseResult { + @usableFromInline + var connection: Connection + @usableFromInline + var timersToCancel: Max2Sequence + @usableFromInline + var wasIdle: Bool + @usableFromInline + var use: ConnectionGroup.ConnectionUse + + @inlinable + init( + connection: Connection, + timersToCancel: Max2Sequence, + wasIdle: Bool, + use: ConnectionGroup.ConnectionUse + ) { + self.connection = connection + self.timersToCancel = timersToCancel + self.wasIdle = wasIdle + self.use = use + } + } + + @usableFromInline + struct ConnectionGroup: Sendable { + @usableFromInline + struct Stats: Hashable, Sendable { + @usableFromInline var connecting: UInt16 = 0 + @usableFromInline var backingOff: UInt16 = 0 + @usableFromInline var idle: UInt16 = 0 + @usableFromInline var leased: UInt16 = 0 + @usableFromInline var runningKeepAlive: UInt16 = 0 + @usableFromInline var closing: UInt16 = 0 + + @usableFromInline var availableStreams: UInt16 = 0 + @usableFromInline var leasedStreams: UInt16 = 0 + + @usableFromInline var soonAvailable: UInt16 { + self.connecting + self.backingOff + self.runningKeepAlive + } + + @usableFromInline var active: UInt16 { + self.idle + self.leased + self.connecting + self.backingOff + } + } + + /// The minimum number of connections + @usableFromInline + let minimumConcurrentConnections: Int + + /// The maximum number of preserved connections + @usableFromInline + let maximumConcurrentConnectionSoftLimit: Int + + /// The absolute maximum number of connections + @usableFromInline + let maximumConcurrentConnectionHardLimit: Int + + @usableFromInline + let keepAlive: Bool + + @usableFromInline + let keepAliveReducesAvailableStreams: Bool + + /// A connectionID generator. + @usableFromInline + let generator: ConnectionIDGenerator + + /// The connections states + @usableFromInline + private(set) var connections: [ConnectionState] + + @usableFromInline + private(set) var stats = Stats() + + @inlinable + init( + generator: ConnectionIDGenerator, + minimumConcurrentConnections: Int, + maximumConcurrentConnectionSoftLimit: Int, + maximumConcurrentConnectionHardLimit: Int, + keepAlive: Bool, + keepAliveReducesAvailableStreams: Bool + ) { + self.generator = generator + self.connections = [] + self.minimumConcurrentConnections = minimumConcurrentConnections + self.maximumConcurrentConnectionSoftLimit = maximumConcurrentConnectionSoftLimit + self.maximumConcurrentConnectionHardLimit = maximumConcurrentConnectionHardLimit + self.keepAlive = keepAlive + self.keepAliveReducesAvailableStreams = keepAliveReducesAvailableStreams + } + + var isEmpty: Bool { + self.connections.isEmpty + } + + @usableFromInline + var canGrow: Bool { + self.stats.active < self.maximumConcurrentConnectionHardLimit + } + + @usableFromInline + var soonAvailableConnections: UInt16 { + self.stats.soonAvailable + } + + // MARK: - Mutations - + + /// A connection's use. Is it persisted or an overflow connection? + @usableFromInline + enum ConnectionUse: Equatable { + case persisted + case demand + case overflow + } + + /// Information around an idle connection. + @usableFromInline + struct AvailableConnectionContext { + /// The connection's use. Either general purpose or for requests with `EventLoop` + /// requirements. + @usableFromInline + var use: ConnectionUse + + @usableFromInline + var info: ConnectionAvailableInfo + } + + /// Information around the failed/closed connection. + @usableFromInline + struct FailedConnectionContext { + /// Connections that are currently starting + @usableFromInline + var connectionsStarting: Int + + @inlinable + init(connectionsStarting: Int) { + self.connectionsStarting = connectionsStarting + } + } + + mutating func refillConnections() -> [ConnectionRequest] { + let existingConnections = self.stats.active + let missingConnection = self.minimumConcurrentConnections - Int(existingConnections) + guard missingConnection > 0 else { + return [] + } + + var requests = [ConnectionRequest]() + requests.reserveCapacity(missingConnection) + + for _ in 0.. ConnectionRequest? { + precondition(self.minimumConcurrentConnections <= self.stats.active) + guard self.maximumConcurrentConnectionSoftLimit > self.stats.active else { + return nil + } + return self.createNewConnection() + } + + @inlinable + mutating func createNewOverflowConnectionIfPossible() -> ConnectionRequest? { + precondition(self.maximumConcurrentConnectionSoftLimit <= self.stats.active) + guard self.maximumConcurrentConnectionHardLimit > self.stats.active else { + return nil + } + return self.createNewConnection() + } + + @inlinable + /*private*/ mutating func createNewConnection() -> ConnectionRequest { + precondition(self.canGrow) + self.stats.connecting += 1 + let connectionID = self.generator.next() + let connection = ConnectionState(id: connectionID) + self.connections.append(connection) + return ConnectionRequest(connectionID: connectionID) + } + + /// A new ``Connection`` was established. + /// + /// This will put the connection into the idle state. + /// + /// - Parameter connection: The new established connection. + /// - Returns: An index and an IdleConnectionContext to determine the next action for the now idle connection. + /// Call ``parkConnection(at:)``, ``leaseConnection(at:)`` or ``closeConnection(at:)`` + /// with the supplied index after this. + @inlinable + mutating func newConnectionEstablished(_ connection: Connection, maxStreams: UInt16) -> (Int, AvailableConnectionContext) { + guard let index = self.connections.firstIndex(where: { $0.id == connection.id }) else { + preconditionFailure("There is a new connection that we didn't request!") + } + self.stats.connecting -= 1 + self.stats.idle += 1 + self.stats.availableStreams += maxStreams + let connectionInfo = self.connections[index].connected(connection, maxStreams: maxStreams) + // TODO: If this is an overflow connection, but we are currently also creating a + // persisted connection, we might want to swap those. + let context = self.makeAvailableConnectionContextForConnection(at: index, info: connectionInfo) + return (index, context) + } + + @inlinable + mutating func backoffNextConnectionAttempt(_ connectionID: Connection.ID) -> ConnectionTimer { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + preconditionFailure("We tried to create a new connection that we know nothing about?") + } + + self.stats.connecting -= 1 + self.stats.backingOff += 1 + + return self.connections[index].failedToConnect() + } + + @usableFromInline + enum BackoffDoneAction { + case createConnection(ConnectionRequest, TimerCancellationToken?) + case cancelTimers(Max2Sequence) + } + + @inlinable + mutating func backoffDone(_ connectionID: Connection.ID, retry: Bool) -> BackoffDoneAction { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + preconditionFailure("We tried to create a new connection that we know nothing about?") + } + + self.stats.backingOff -= 1 + + if retry || self.stats.active < self.minimumConcurrentConnections { + self.stats.connecting += 1 + let backoffTimerCancellation = self.connections[index].retryConnect() + return .createConnection(.init(connectionID: connectionID), backoffTimerCancellation) + } + + let backoffTimerCancellation = self.connections[index].destroyBackingOffConnection() + var timerCancellations = Max2Sequence(backoffTimerCancellation) + + if let timerCancellationToken = self.swapForDeletion(index: index) { + timerCancellations.append(timerCancellationToken) + } + return .cancelTimers(timerCancellations) + } + + @inlinable + mutating func timerScheduled( + _ timer: ConnectionTimer, + cancelContinuation: TimerCancellationToken + ) -> TimerCancellationToken? { + guard let index = self.connections.firstIndex(where: { $0.id == timer.connectionID }) else { + return cancelContinuation + } + + return self.connections[index].timerScheduled(timer, cancelContinuation: cancelContinuation) + } + + // MARK: Leasing and releasing + + /// Lease a connection, if an idle connection is available. + /// + /// - Returns: A connection to execute a request on. + @inlinable + mutating func leaseConnection() -> LeaseResult? { + if self.stats.availableStreams == 0 { + return nil + } + + guard let index = self.findAvailableConnection() else { + preconditionFailure("Stats and actual count are of.") + } + + return self.leaseConnection(at: index, streams: 1) + } + + @usableFromInline + enum LeasedConnectionOrStartingCount { + case leasedConnection(LeaseResult) + case startingCount(UInt16) + } + + @inlinable + mutating func leaseConnectionOrSoonAvailableConnectionCount() -> LeasedConnectionOrStartingCount { + if let result = self.leaseConnection() { + return .leasedConnection(result) + } + return .startingCount(self.stats.soonAvailable) + } + + @inlinable + mutating func leaseConnection(at index: Int, streams: UInt16) -> LeaseResult { + let leaseResult = self.connections[index].lease(streams: streams) + let use = self.getConnectionUse(index: index) + + if leaseResult.wasIdle { + self.stats.idle -= 1 + self.stats.leased += 1 + } + self.stats.leasedStreams += streams + self.stats.availableStreams -= streams + return LeaseResult( + connection: leaseResult.connection, + timersToCancel: leaseResult.timersToCancel, + wasIdle: leaseResult.wasIdle, + use: use + ) + } + + @inlinable + mutating func parkConnection(at index: Int) -> Max2Sequence { + let scheduleIdleTimeoutTimer: Bool + switch index { + case 0.. (Int, AvailableConnectionContext) { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + preconditionFailure("A connection that we don't know was released? Something is very wrong...") + } + + let connectionInfo = self.connections[index].release(streams: streams) + self.stats.availableStreams += streams + self.stats.leasedStreams -= streams + switch connectionInfo { + case .idle: + self.stats.idle += 1 + self.stats.leased -= 1 + case .leased: + break + } + + let context = self.makeAvailableConnectionContextForConnection(at: index, info: connectionInfo) + return (index, context) + } + + @inlinable + mutating func keepAliveIfIdle(_ connectionID: Connection.ID) -> KeepAliveAction? { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + // because of a race this connection (connection close runs against trigger of ping pong) + // was already removed from the state machine. + return nil + } + + guard let action = self.connections[index].runKeepAliveIfIdle(reducesAvailableStreams: self.keepAliveReducesAvailableStreams) else { + return nil + } + + self.stats.runningKeepAlive += 1 + if self.keepAliveReducesAvailableStreams { + self.stats.availableStreams -= 1 + } + + return action + } + + @inlinable + mutating func keepAliveSucceeded(_ connectionID: Connection.ID) -> (Int, AvailableConnectionContext)? { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + preconditionFailure("A connection that we don't know was released? Something is very wrong...") + } + + guard let connectionInfo = self.connections[index].keepAliveSucceeded() else { + // if we don't get connection info here this means, that the connection already was + // transitioned to closing. when we did this we already decremented the + // runningKeepAlive timer. + return nil + } + + self.stats.runningKeepAlive -= 1 + if self.keepAliveReducesAvailableStreams { + self.stats.availableStreams += 1 + } + + let context = self.makeAvailableConnectionContextForConnection(at: index, info: connectionInfo) + return (index, context) + } + + // MARK: Connection close/removal + + @usableFromInline + struct CloseAction { + @usableFromInline + private(set) var connection: Connection + + @usableFromInline + private(set) var timersToCancel: Max2Sequence + + @inlinable + init(connection: Connection, timersToCancel: Max2Sequence) { + self.connection = connection + self.timersToCancel = timersToCancel + } + } + + /// Closes the connection at the given index. + @inlinable + mutating func closeConnectionIfIdle(at index: Int) -> CloseAction { + guard let closeAction = self.connections[index].closeIfIdle() else { + preconditionFailure("Invalid state: \(self)") + } + + self.stats.idle -= 1 + self.stats.closing += 1 + +// if idleState.runningKeepAlive { +// self.stats.runningKeepAlive -= 1 +// if self.keepAliveReducesAvailableStreams { +// self.stats.availableStreams += 1 +// } +// } + + self.stats.availableStreams -= closeAction.maxStreams + + return CloseAction( + connection: closeAction.connection!, + timersToCancel: closeAction.cancelTimers + ) + } + + @inlinable + mutating func closeConnectionIfIdle(_ connectionID: Connection.ID) -> CloseAction? { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + // because of a race this connection (connection close runs against trigger of timeout) + // was already removed from the state machine. + return nil + } + + if index < self.minimumConcurrentConnections { + // because of a race a connection might receive a idle timeout after it was moved into + // the persisted connections. If a connection is now persisted, we now need to ignore + // the trigger + return nil + } + + return self.closeConnectionIfIdle(at: index) + } + + /// Connection closed. Call this method, if a connection is closed. + /// + /// This will put the position into the closed state. + /// + /// - Parameter connectionID: The failed connection's id. + /// - Returns: An optional index and an IdleConnectionContext to determine the next action for the closed connection. + /// You must call ``removeConnection(at:)`` or ``replaceConnection(at:)`` with the + /// supplied index after this. If nil is returned the connection was closed by the state machine and was + /// therefore already removed. + @inlinable + mutating func connectionClosed(_ connectionID: Connection.ID) -> FailedConnectionContext? { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + return nil + } + + let closedAction = self.connections[index].closed() + + if closedAction.wasRunningKeepAlive { + self.stats.runningKeepAlive -= 1 + } + self.stats.leasedStreams -= closedAction.usedStreams + self.stats.availableStreams -= closedAction.maxStreams - closedAction.usedStreams + + switch closedAction.previousConnectionState { + case .idle: + self.stats.idle -= 1 + + case .leased: + self.stats.leased -= 1 + + case .closing: + self.stats.closing -= 1 + } + + let lastIndex = self.connections.index(before: self.connections.endIndex) + + if index == lastIndex { + self.connections.remove(at: index) + } else { + self.connections.swapAt(index, lastIndex) + self.connections.remove(at: lastIndex) + } + + return FailedConnectionContext(connectionsStarting: 0) + } + + // MARK: Shutdown + + mutating func triggerForceShutdown(_ cleanup: inout ConnectionAction.Shutdown) { + for var connectionState in self.connections { + guard let closeAction = connectionState.close() else { + continue + } + + if let connection = closeAction.connection { + cleanup.connections.append(connection) + } + cleanup.timersToCancel.append(contentsOf: closeAction.cancelTimers) + } + + self.connections = [] + } + + // MARK: - Private functions - + + @usableFromInline + /*private*/ func getConnectionUse(index: Int) -> ConnectionUse { + switch index { + case 0.. AvailableConnectionContext { + precondition(self.connections[index].isAvailable) + let use = self.getConnectionUse(index: index) + return AvailableConnectionContext(use: use, info: info) + } + + @inlinable + /*private*/ func findAvailableConnection() -> Int? { + return self.connections.firstIndex(where: { $0.isAvailable }) + } + + @inlinable + /*private*/ mutating func swapForDeletion(index indexToDelete: Int) -> TimerCancellationToken? { + let maybeLastConnectedIndex = self.connections.lastIndex(where: { $0.isConnected }) + + if maybeLastConnectedIndex == nil || maybeLastConnectedIndex! < indexToDelete { + self.removeO1(indexToDelete) + return nil + } + + // if maybeLastConnectedIndex == nil, we return early in the above if case. + let lastConnectedIndex = maybeLastConnectedIndex! + + switch indexToDelete { + case 0.. State.Timer { - defer { self.nextTimerID += 1 } - return State.Timer(id: self.nextTimerID) - } - /// The connection failed to start @inlinable mutating func failedToConnect() -> ConnectionTimer { switch self.state { case .starting: - let backoffTimerState = self.nextTimer() + let backoffTimerState = self._nextTimer() self.state = .backingOff(backoffTimerState) return ConnectionTimer(timerID: backoffTimerState.timerID, connectionID: self.id, usecase: .backoff) @@ -311,6 +295,17 @@ extension PoolStateMachine { } } + @inlinable + mutating func destroyBackingOffConnection() -> TimerCancellationToken? { + switch self.state { + case .backingOff(let timer): + self.state = .closed + return timer.cancellationContinuation + case .starting, .idle, .leased, .closing, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + @usableFromInline struct LeaseAction { @usableFromInline @@ -468,78 +463,211 @@ extension PoolStateMachine { } } + @inlinable + mutating func cancelIdleTimer() -> TimerCancellationToken? { + switch self.state { + case .starting, .backingOff, .leased, .closing, .closed: + return nil + + case .idle(let connection, let maxStreams, let keepAlive, let idleTimer): + self.state = .idle(connection, maxStreams: maxStreams, keepAlive: keepAlive, idleTimer: nil) + return idleTimer?.cancellationContinuation + } + } + @usableFromInline struct CloseAction { + @usableFromInline - var connection: Connection + enum PreviousConnectionState { + case idle + case leased + case closing + case backingOff + } + + @usableFromInline + var connection: Connection? + @usableFromInline + var previousConnectionState: PreviousConnectionState @usableFromInline var cancelTimers: Max2Sequence @usableFromInline + var usedStreams: UInt16 + @usableFromInline var maxStreams: UInt16 @inlinable - init(connection: Connection, cancelTimers: Max2Sequence, maxStreams: UInt16) { + init( + connection: Connection?, + previousConnectionState: PreviousConnectionState, + cancelTimers: Max2Sequence, + usedStreams: UInt16, + maxStreams: UInt16 + ) { self.connection = connection + self.previousConnectionState = previousConnectionState self.cancelTimers = cancelTimers + self.usedStreams = usedStreams self.maxStreams = maxStreams } } @inlinable - mutating func close() -> CloseAction { + mutating func closeIfIdle() -> CloseAction? { switch self.state { case .idle(let connection, let maxStreams, var keepAlive, let idleTimerState): self.state = .closing(connection) return CloseAction( connection: connection, + previousConnectionState: .idle, cancelTimers: Max2Sequence( keepAlive.cancelTimerIfScheduled(), idleTimerState?.cancellationContinuation ), + usedStreams: keepAlive.usedStreams, maxStreams: maxStreams ) - case .backingOff, .starting, .leased, .closing, .closed: + case .leased, .closed: + return nil + + case .backingOff, .starting, .closing: preconditionFailure("Invalid state: \(self.state)") } } @inlinable - mutating func closeIfIdle() -> CloseAction? { + mutating func close() -> CloseAction? { switch self.state { - case .idle: - return self.close() - case .leased, .closed: + case .starting: + // If we are currently starting, there is nothing we can do about it right now. + // Only once the connection has come up, or failed, we can actually act. return nil - case .backingOff, .starting, .closing: - preconditionFailure("Invalid state: \(self.state)") + + case .closing, .closed: + // If we are already closing, we can't do anything else. + return nil + + case .idle(let connection, let maxStreams, var keepAlive, let idleTimerState): + self.state = .closing(connection) + return CloseAction( + connection: connection, + previousConnectionState: .idle, + cancelTimers: Max2Sequence( + keepAlive.cancelTimerIfScheduled(), + idleTimerState?.cancellationContinuation + ), + usedStreams: keepAlive.usedStreams, + maxStreams: maxStreams + ) + + case .leased(let connection, usedStreams: let usedStreams, maxStreams: let maxStreams, var keepAlive): + self.state = .closing(connection) + return CloseAction( + connection: connection, + previousConnectionState: .leased, + cancelTimers: Max2Sequence( + keepAlive.cancelTimerIfScheduled() + ), + usedStreams: keepAlive.usedStreams + usedStreams, + maxStreams: maxStreams + ) + + case .backingOff(let timer): + self.state = .closed + return CloseAction( + connection: nil, + previousConnectionState: .backingOff, + cancelTimers: Max2Sequence(timer.cancellationContinuation), + usedStreams: 0, + maxStreams: 0 + ) } } @usableFromInline - struct ShutdownAction { + struct ClosedAction { + @usableFromInline - var connection: Connection? + enum PreviousConnectionState { + case idle + case leased + case closing + } + @usableFromInline - var timersToCancel: Max2Sequence + var previousConnectionState: PreviousConnectionState + @usableFromInline + var cancelTimers: Max2Sequence @usableFromInline var maxStreams: UInt16 @usableFromInline var usedStreams: UInt16 + @usableFromInline + var wasRunningKeepAlive: Bool @inlinable init( - connection: Connection? = nil, - timersToCancel: Max2Sequence = .init(), - maxStreams: UInt16 = 0, - usedStreams: UInt16 = 0 + previousConnectionState: PreviousConnectionState, + cancelTimers: Max2Sequence, + maxStreams: UInt16, + usedStreams: UInt16, + wasRunningKeepAlive: Bool ) { - self.connection = connection - self.timersToCancel = timersToCancel + self.previousConnectionState = previousConnectionState + self.cancelTimers = cancelTimers self.maxStreams = maxStreams self.usedStreams = usedStreams + self.wasRunningKeepAlive = wasRunningKeepAlive + } + } + + @inlinable + mutating func closed() -> ClosedAction { + switch self.state { + case .starting, .backingOff, .closed: + preconditionFailure("Invalid state: \(self.state)") + + case .idle(_, let maxStreams, var keepAlive, let idleTimer): + self.state = .closed + return ClosedAction( + previousConnectionState: .idle, + cancelTimers: .init(keepAlive.cancelTimerIfScheduled(), idleTimer?.cancellationContinuation), + maxStreams: maxStreams, + usedStreams: keepAlive.usedStreams, + wasRunningKeepAlive: keepAlive.isRunning + ) + + case .leased(_, let usedStreams, let maxStreams, let keepAlive): + self.state = .closed + return ClosedAction( + previousConnectionState: .leased, + cancelTimers: .init(), + maxStreams: maxStreams, + usedStreams: usedStreams + keepAlive.usedStreams, + wasRunningKeepAlive: keepAlive.isRunning + ) + + case .closing: + self.state = .closed + return ClosedAction( + previousConnectionState: .closing, + cancelTimers: .init(), + maxStreams: 0, + usedStreams: 0, + wasRunningKeepAlive: false + ) } } + + // MARK: - Private Methods - + + @inlinable + mutating /*private*/ func _nextTimer() -> State.Timer { + defer { self.nextTimerID += 1 } + return State.Timer(id: self.nextTimerID) + } } @usableFromInline diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index dc18784f..29349e56 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -37,7 +37,7 @@ struct PoolStateMachine< ConnectionID: Hashable & Sendable, Request: ConnectionRequestProtocol, RequestID, - TimerCancellationToken + TimerCancellationToken: Sendable > where Connection.ID == ConnectionID, ConnectionIDGenerator.ID == ConnectionID, RequestID == Request.ID { @usableFromInline diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift new file mode 100644 index 00000000..4e3a1647 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift @@ -0,0 +1,294 @@ +import XCTest +@testable import _ConnectionPoolModule + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class PoolStateMachine_ConnectionGroupTests: XCTestCase { + var idGenerator: ConnectionIDGenerator! + + override func setUp() { + self.idGenerator = ConnectionIDGenerator() + super.setUp() + } + + override func tearDown() { + self.idGenerator = nil + super.tearDown() + } + + func testRefillConnections() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 4, + maximumConcurrentConnectionSoftLimit: 4, + maximumConcurrentConnectionHardLimit: 4, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + XCTAssertTrue(connections.isEmpty) + let requests = connections.refillConnections() + XCTAssertFalse(connections.isEmpty) + + XCTAssertEqual(requests.count, 4) + XCTAssertNil(connections.createNewDemandConnectionIfPossible()) + XCTAssertNil(connections.createNewOverflowConnectionIfPossible()) + XCTAssertEqual(connections.stats, .init(connecting: 4)) + XCTAssertEqual(connections.soonAvailableConnections, 4) + + let requests2 = connections.refillConnections() + XCTAssertTrue(requests2.isEmpty) + + var connected: UInt16 = 0 + for request in requests { + let newConnection = MockConnection(id: request.connectionID) + let (_, context) = connections.newConnectionEstablished(newConnection, maxStreams: 1) + XCTAssertEqual(context.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(context.use, .persisted) + connected += 1 + XCTAssertEqual(connections.stats, .init(connecting: 4 - connected, idle: connected, availableStreams: connected)) + XCTAssertEqual(connections.soonAvailableConnections, 4 - connected) + } + + let requests3 = connections.refillConnections() + XCTAssertTrue(requests3.isEmpty) + } + + func testMakeConnectionLeaseItAndDropItHappyPath() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 0, + maximumConcurrentConnectionSoftLimit: 4, + maximumConcurrentConnectionHardLimit: 4, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + let requests = connections.refillConnections() + XCTAssertTrue(connections.isEmpty) + XCTAssertTrue(requests.isEmpty) + + guard let request = connections.createNewDemandConnectionIfPossible() else { + return XCTFail("Expected to receive a connection request") + } + XCTAssertEqual(request, .init(connectionID: 0)) + XCTAssertFalse(connections.isEmpty) + XCTAssertEqual(connections.soonAvailableConnections, 1) + XCTAssertEqual(connections.stats, .init(connecting: 1)) + + let newConnection = MockConnection(id: request.connectionID) + let (_, establishedContext) = connections.newConnectionEstablished(newConnection, maxStreams: 1) + XCTAssertEqual(establishedContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(establishedContext.use, .demand) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + XCTAssertEqual(connections.soonAvailableConnections, 0) + + guard case .leasedConnection(let leaseResult) = connections.leaseConnectionOrSoonAvailableConnectionCount() else { + return XCTFail("Expected to lease a connection") + } + XCTAssert(newConnection === leaseResult.connection) + XCTAssertEqual(connections.stats, .init(leased: 1, leasedStreams: 1)) + + let (index, releasedContext) = connections.releaseConnection(leaseResult.connection.id, streams: 1) + XCTAssertEqual(releasedContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(releasedContext.use, .demand) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + + let parkTimers = connections.parkConnection(at: index) + XCTAssertEqual(parkTimers, [ + .init(timerID: 0, connectionID: newConnection.id, usecase: .keepAlive), + .init(timerID: 1, connectionID: newConnection.id, usecase: .idleTimeout), + ]) + + guard let keepAliveAction = connections.keepAliveIfIdle(newConnection.id) else { + return XCTFail("Expected to get a connection for ping pong") + } + XCTAssert(newConnection === keepAliveAction.connection) + XCTAssertEqual(connections.stats, .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) + + guard let (_, pingPongContext) = connections.keepAliveSucceeded(newConnection.id) else { + return XCTFail("Expected to get an AvailableContext") + } + XCTAssertEqual(pingPongContext.info, .idle(availableStreams: 1, newIdle: false)) + XCTAssertEqual(releasedContext.use, .demand) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + + guard let closeAction = connections.closeConnectionIfIdle(newConnection.id) else { + return XCTFail("Expected to get a connection for ping pong") + } + XCTAssertEqual(closeAction.timersToCancel, []) + XCTAssert(closeAction.connection === newConnection) + XCTAssertEqual(connections.stats, .init(closing: 1, availableStreams: 0)) + + let closeContext = connections.connectionClosed(newConnection.id) + XCTAssertEqual(closeContext?.connectionsStarting, 0) + XCTAssertTrue(connections.isEmpty) + XCTAssertEqual(connections.stats, .init()) + } + + func testBackoffDoneCreatesANewConnectionToReachMinimumConnectionsEvenThoughRetryIsSetToFalse() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 1, + maximumConcurrentConnectionSoftLimit: 4, + maximumConcurrentConnectionHardLimit: 4, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + let requests = connections.refillConnections() + XCTAssertEqual(connections.stats, .init(connecting: 1)) + XCTAssertEqual(connections.soonAvailableConnections, 1) + XCTAssertFalse(connections.isEmpty) + XCTAssertEqual(requests.count, 1) + + guard let request = requests.first else { return XCTFail("Expected to receive a connection request") } + XCTAssertEqual(request, .init(connectionID: 0)) + + let backoffTimer = connections.backoffNextConnectionAttempt(request.connectionID) + XCTAssertEqual(connections.stats, .init(backingOff: 1)) + let backoffTimerCancellationToken = MockTimerCancellationToken(backoffTimer) + XCTAssertNil(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken)) + + let backoffDoneAction = connections.backoffDone(request.connectionID, retry: false) + XCTAssertEqual(backoffDoneAction, .createConnection(.init(connectionID: 0), backoffTimerCancellationToken)) + + XCTAssertEqual(connections.stats, .init(connecting: 1)) + } + + func testBackoffDoneCancelsIdleTimerIfAPersistedConnectionIsNotRetried() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 2, + maximumConcurrentConnectionSoftLimit: 4, + maximumConcurrentConnectionHardLimit: 4, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + let requests = connections.refillConnections() + XCTAssertEqual(connections.stats, .init(connecting: 2)) + XCTAssertEqual(connections.soonAvailableConnections, 2) + XCTAssertFalse(connections.isEmpty) + XCTAssertEqual(requests.count, 2) + + var requestIterator = requests.makeIterator() + guard let firstRequest = requestIterator.next(), let secondRequest = requestIterator.next() else { + return XCTFail("Expected to get two requests") + } + + guard let thirdRequest = connections.createNewDemandConnectionIfPossible() else { + return XCTFail("Expected to get another request") + } + XCTAssertEqual(connections.stats, .init(connecting: 3)) + + let newSecondConnection = MockConnection(id: secondRequest.connectionID) + let (_, establishedSecondConnectionContext) = connections.newConnectionEstablished(newSecondConnection, maxStreams: 1) + XCTAssertEqual(establishedSecondConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(establishedSecondConnectionContext.use, .persisted) + XCTAssertEqual(connections.stats, .init(connecting: 2, idle: 1, availableStreams: 1)) + XCTAssertEqual(connections.soonAvailableConnections, 2) + + let newThirdConnection = MockConnection(id: thirdRequest.connectionID) + let (thirdConnectionIndex, establishedThirdConnectionContext) = connections.newConnectionEstablished(newThirdConnection, maxStreams: 1) + XCTAssertEqual(establishedThirdConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(establishedThirdConnectionContext.use, .demand) + XCTAssertEqual(connections.stats, .init(connecting: 1, idle: 2, availableStreams: 2)) + XCTAssertEqual(connections.soonAvailableConnections, 1) + let thirdConnKeepTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: thirdRequest.connectionID, usecase: .keepAlive) + let thirdConnIdleTimer = TestPoolStateMachine.ConnectionTimer(timerID: 1, connectionID: thirdRequest.connectionID, usecase: .idleTimeout) + let thirdConnIdleTimerCancellationToken = MockTimerCancellationToken(thirdConnIdleTimer) + XCTAssertEqual(connections.parkConnection(at: thirdConnectionIndex), [thirdConnKeepTimer, thirdConnIdleTimer]) + + XCTAssertNil(connections.timerScheduled(thirdConnKeepTimer, cancelContinuation: .init(thirdConnKeepTimer))) + XCTAssertNil(connections.timerScheduled(thirdConnIdleTimer, cancelContinuation: thirdConnIdleTimerCancellationToken)) + + let backoffTimer = connections.backoffNextConnectionAttempt(firstRequest.connectionID) + XCTAssertEqual(connections.stats, .init(backingOff: 1, idle: 2, availableStreams: 2)) + let backoffTimerCancellationToken = MockTimerCancellationToken(backoffTimer) + XCTAssertNil(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken)) + XCTAssertEqual(connections.stats, .init(backingOff: 1, idle: 2, availableStreams: 2)) + + // connection three should be moved to connection one and for this reason become permanent + + XCTAssertEqual(connections.backoffDone(firstRequest.connectionID, retry: false), .cancelTimers([backoffTimerCancellationToken, thirdConnIdleTimerCancellationToken])) + XCTAssertEqual(connections.stats, .init(idle: 2, availableStreams: 2)) + + XCTAssertNil(connections.closeConnectionIfIdle(newThirdConnection.id)) + } + + func testBackoffDoneReturnsNilIfOverflowConnection() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 0, + maximumConcurrentConnectionSoftLimit: 4, + maximumConcurrentConnectionHardLimit: 4, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + guard let firstRequest = connections.createNewDemandConnectionIfPossible() else { + return XCTFail("Expected to get two requests") + } + + guard let secondRequest = connections.createNewDemandConnectionIfPossible() else { + return XCTFail("Expected to get another request") + } + XCTAssertEqual(connections.stats, .init(connecting: 2)) + + let newFirstConnection = MockConnection(id: firstRequest.connectionID) + let (_, establishedFirstConnectionContext) = connections.newConnectionEstablished(newFirstConnection, maxStreams: 1) + XCTAssertEqual(establishedFirstConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(establishedFirstConnectionContext.use, .demand) + XCTAssertEqual(connections.stats, .init(connecting: 1, idle: 1, availableStreams: 1)) + XCTAssertEqual(connections.soonAvailableConnections, 1) + + let backoffTimer = connections.backoffNextConnectionAttempt(secondRequest.connectionID) + let backoffTimerCancellationToken = MockTimerCancellationToken(backoffTimer) + XCTAssertEqual(connections.stats, .init(backingOff: 1, idle: 1, availableStreams: 1)) + XCTAssertNil(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken)) + + XCTAssertEqual(connections.backoffDone(secondRequest.connectionID, retry: false), .cancelTimers([backoffTimerCancellationToken])) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + + XCTAssertNotNil(connections.closeConnectionIfIdle(newFirstConnection.id)) + } + + func testPingPong() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 1, + maximumConcurrentConnectionSoftLimit: 4, + maximumConcurrentConnectionHardLimit: 4, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + let requests = connections.refillConnections() + XCTAssertFalse(connections.isEmpty) + XCTAssertEqual(connections.stats, .init(connecting: 1)) + + XCTAssertEqual(requests.count, 1) + guard let firstRequest = requests.first else { return XCTFail("Expected to have a request here") } + + let newConnection = MockConnection(id: firstRequest.connectionID) + let (connectionIndex, establishedConnectionContext) = connections.newConnectionEstablished(newConnection, maxStreams: 1) + XCTAssertEqual(establishedConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(establishedConnectionContext.use, .persisted) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + let timers = connections.parkConnection(at: connectionIndex) + let keepAliveTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: firstRequest.connectionID, usecase: .keepAlive) + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + XCTAssertEqual(timers, [keepAliveTimer]) + XCTAssertNil(connections.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + let keepAliveAction = connections.keepAliveIfIdle(newConnection.id) + XCTAssertEqual(keepAliveAction, .init(connection: newConnection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken)) + XCTAssertEqual(connections.stats, .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) + + guard let (_, afterPingIdleContext) = connections.keepAliveSucceeded(newConnection.id) else { + return XCTFail("Expected to receive an AvailableContext") + } + XCTAssertEqual(afterPingIdleContext.info, .idle(availableStreams: 1, newIdle: false)) + XCTAssertEqual(afterPingIdleContext.use, .persisted) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + } +} diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift index b1622d0d..7751837e 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift @@ -10,19 +10,19 @@ final class PoolStateMachine_ConnectionStateTests: XCTestCase { let connectionID = 1 var state = TestConnectionState(id: connectionID) XCTAssertEqual(state.id, connectionID) - XCTAssertEqual(state.isIdleOrRunningKeepAlive, false) + XCTAssertEqual(state.isIdle, false) XCTAssertEqual(state.isAvailable, false) XCTAssertEqual(state.isConnected, false) XCTAssertEqual(state.isLeased, false) let connection = MockConnection(id: connectionID) XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(state.isIdleOrRunningKeepAlive, true) + XCTAssertEqual(state.isIdle, true) XCTAssertEqual(state.isAvailable, true) XCTAssertEqual(state.isConnected, true) XCTAssertEqual(state.isLeased, false) XCTAssertEqual(state.lease(streams: 1), .init(connection: connection, timersToCancel: .init(), wasIdle: true)) - XCTAssertEqual(state.isIdleOrRunningKeepAlive, false) + XCTAssertEqual(state.isIdle, false) XCTAssertEqual(state.isAvailable, false) XCTAssertEqual(state.isConnected, true) XCTAssertEqual(state.isLeased, true) @@ -257,7 +257,7 @@ final class PoolStateMachine_ConnectionStateTests: XCTestCase { XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) XCTAssertNil(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken)) - XCTAssertEqual(state.closeIfIdle(), .init(connection: connection, cancelTimers: [keepAliveTimerCancellationToken, idleTimerCancellationToken], maxStreams: 1)) + XCTAssertEqual(state.closeIfIdle(), .init(connection: connection, previousConnectionState: .idle, cancelTimers: [keepAliveTimerCancellationToken, idleTimerCancellationToken], usedStreams: 0, maxStreams: 1)) XCTAssertEqual(state.runKeepAliveIfIdle(reducesAvailableStreams: true), .none) } From 472ff4ae68bd9b8d59d978137812137ee8162f4a Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 25 Oct 2023 22:44:46 +0200 Subject: [PATCH 203/292] Add `PoolStateMachine` (#427) --- .../PoolStateMachine+ConnectionGroup.swift | 61 ++- .../PoolStateMachine+RequestQueue.swift | 10 +- .../PoolStateMachine.swift | 484 +++++++++++++++++- ...tSequence.swift => TinyFastSequence.swift} | 80 ++- ...oolStateMachine+ConnectionGroupTests.swift | 2 +- .../PoolStateMachineTests.swift | 217 +++++++- ...tSequence.swift => TinyFastSequence.swift} | 16 +- 7 files changed, 814 insertions(+), 56 deletions(-) rename Sources/ConnectionPoolModule/{OneElementFastSequence.swift => TinyFastSequence.swift} (58%) rename Tests/ConnectionPoolModuleTests/{OneElementFastSequence.swift => TinyFastSequence.swift} (82%) diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift index 8ec99c7d..16970599 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -134,19 +134,6 @@ extension PoolStateMachine { var info: ConnectionAvailableInfo } - /// Information around the failed/closed connection. - @usableFromInline - struct FailedConnectionContext { - /// Connections that are currently starting - @usableFromInline - var connectionsStarting: Int - - @inlinable - init(connectionsStarting: Int) { - self.connectionsStarting = connectionsStarting - } - } - mutating func refillConnections() -> [ConnectionRequest] { let existingConnections = self.stats.active let missingConnection = self.minimumConcurrentConnections - Int(existingConnections) @@ -477,6 +464,31 @@ extension PoolStateMachine { return self.closeConnectionIfIdle(at: index) } + /// Information around the failed/closed connection. + @usableFromInline + struct ClosedAction { + /// Connections that are currently starting + @usableFromInline + var connectionsStarting: Int + + @usableFromInline + var timersToCancel: TinyFastSequence + + @usableFromInline + var newConnectionRequest: ConnectionRequest? + + @inlinable + init( + connectionsStarting: Int, + timersToCancel: TinyFastSequence, + newConnectionRequest: ConnectionRequest? = nil + ) { + self.connectionsStarting = connectionsStarting + self.timersToCancel = timersToCancel + self.newConnectionRequest = newConnectionRequest + } + } + /// Connection closed. Call this method, if a connection is closed. /// /// This will put the position into the closed state. @@ -487,12 +499,13 @@ extension PoolStateMachine { /// supplied index after this. If nil is returned the connection was closed by the state machine and was /// therefore already removed. @inlinable - mutating func connectionClosed(_ connectionID: Connection.ID) -> FailedConnectionContext? { + mutating func connectionClosed(_ connectionID: Connection.ID) -> ClosedAction { guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { - return nil + preconditionFailure("All connections that have been created should say goodbye exactly once!") } let closedAction = self.connections[index].closed() + var timersToCancel = TinyFastSequence(closedAction.cancelTimers) if closedAction.wasRunningKeepAlive { self.stats.runningKeepAlive -= 1 @@ -511,16 +524,22 @@ extension PoolStateMachine { self.stats.closing -= 1 } - let lastIndex = self.connections.index(before: self.connections.endIndex) + if let cancellationTimer = self.swapForDeletion(index: index) { + timersToCancel.append(cancellationTimer) + } - if index == lastIndex { - self.connections.remove(at: index) + let newConnectionRequest: ConnectionRequest? + if self.connections.count < self.minimumConcurrentConnections { + newConnectionRequest = .init(connectionID: self.generator.next()) } else { - self.connections.swapAt(index, lastIndex) - self.connections.remove(at: lastIndex) + newConnectionRequest = .none } - return FailedConnectionContext(connectionsStarting: 0) + return ClosedAction( + connectionsStarting: 0, + timersToCancel: timersToCancel, + newConnectionRequest: newConnectionRequest + ) } // MARK: Shutdown diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift b/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift index 7e3c6607..f1d6f4e4 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift @@ -10,7 +10,7 @@ extension PoolStateMachine { /// request from the dictionary and keep it inside the queue. Whenever we pop a request from the deque, we validate /// that it hasn't been cancelled in the meantime by checking if the popped request is still in the `requests` dictionary. @usableFromInline - struct RequestQueue { + struct RequestQueue: Sendable { @usableFromInline private(set) var queue: Deque @@ -40,8 +40,8 @@ extension PoolStateMachine { } @inlinable - mutating func pop(max: UInt16) -> OneElementFastSequence { - var result = OneElementFastSequence() + mutating func pop(max: UInt16) -> TinyFastSequence { + var result = TinyFastSequence() result.reserveCapacity(Int(max)) var popped = 0 while let requestID = self.queue.popFirst(), popped < max { @@ -61,8 +61,8 @@ extension PoolStateMachine { } @inlinable - mutating func removeAll() -> OneElementFastSequence { - let result = OneElementFastSequence(self.requests.values) + mutating func removeAll() -> TinyFastSequence { + let result = TinyFastSequence(self.requests.values) self.requests.removeAll() self.queue.removeAll() return result diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index 29349e56..aa62d749 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -6,7 +6,7 @@ import Glibc @usableFromInline @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -struct PoolConfiguration { +struct PoolConfiguration: Sendable { /// The minimum number of connections to preserve in the pool. /// /// If the pool is mostly idle and the remote servers closes idle connections, @@ -38,10 +38,10 @@ struct PoolStateMachine< Request: ConnectionRequestProtocol, RequestID, TimerCancellationToken: Sendable -> where Connection.ID == ConnectionID, ConnectionIDGenerator.ID == ConnectionID, RequestID == Request.ID { - +>: Sendable where Connection.ID == ConnectionID, ConnectionIDGenerator.ID == ConnectionID, RequestID == Request.ID { + @usableFromInline - struct ConnectionRequest: Equatable { + struct ConnectionRequest: Hashable, Sendable { @usableFromInline var connectionID: ConnectionID @inlinable @@ -50,6 +50,21 @@ struct PoolStateMachine< } } + @usableFromInline + struct Action { + @usableFromInline let request: RequestAction + @usableFromInline let connection: ConnectionAction + + @inlinable + init(request: RequestAction, connection: ConnectionAction) { + self.request = request + self.connection = connection + } + + @inlinable + static func none() -> Action { Action(request: .none, connection: .none) } + } + @usableFromInline enum ConnectionAction { @usableFromInline @@ -67,15 +82,32 @@ struct PoolStateMachine< } case scheduleTimers(Max2Sequence) - case makeConnection(ConnectionRequest, TimerCancellationToken?) + case makeConnection(ConnectionRequest, TinyFastSequence) case runKeepAlive(Connection, TimerCancellationToken?) - case cancelTimers(Max2Sequence) - case closeConnection(Connection) + case cancelTimers(TinyFastSequence) + case closeConnection(Connection, Max2Sequence) case shutdown(Shutdown) case none } + @usableFromInline + enum RequestAction { + case leaseConnection(TinyFastSequence, Connection) + + case failRequest(Request, ConnectionPoolError) + case failRequests(TinyFastSequence, ConnectionPoolError) + + case none + } + + @usableFromInline + enum PoolState: Sendable { + case running + case shuttingDown(graceful: Bool) + case shutDown + } + @usableFromInline struct Timer: Hashable, Sendable { @usableFromInline @@ -84,10 +116,448 @@ struct PoolStateMachine< @usableFromInline var duration: Duration + @inlinable + var connectionID: ConnectionID { + self.underlying.connectionID + } + @inlinable init(_ connectionTimer: ConnectionTimer, duration: Duration) { self.underlying = connectionTimer self.duration = duration } } + + @usableFromInline let configuration: PoolConfiguration + @usableFromInline let generator: ConnectionIDGenerator + + @usableFromInline + private(set) var connections: ConnectionGroup + @usableFromInline + private(set) var requestQueue: RequestQueue + @usableFromInline + private(set) var poolState: PoolState = .running + @usableFromInline + private(set) var cacheNoMoreConnectionsAllowed: Bool = false + + @usableFromInline + private(set) var failedConsecutiveConnectionAttempts: Int = 0 + + @inlinable + init( + configuration: PoolConfiguration, + generator: ConnectionIDGenerator, + timerCancellationTokenType: TimerCancellationToken.Type + ) { + self.configuration = configuration + self.generator = generator + self.connections = ConnectionGroup( + generator: generator, + minimumConcurrentConnections: configuration.minimumConnectionCount, + maximumConcurrentConnectionSoftLimit: configuration.maximumConnectionSoftLimit, + maximumConcurrentConnectionHardLimit: configuration.maximumConnectionHardLimit, + keepAlive: configuration.keepAliveDuration != nil, + keepAliveReducesAvailableStreams: true + ) + self.requestQueue = RequestQueue() + } + + mutating func refillConnections() -> [ConnectionRequest] { + return self.connections.refillConnections() + } + + @inlinable + mutating func leaseConnection(_ request: Request) -> Action { + switch self.poolState { + case .running: + break + + case .shuttingDown, .shutDown: + return .init( + request: .failRequest(request, ConnectionPoolError.poolShutdown), + connection: .none + ) + } + + if !self.requestQueue.isEmpty && self.cacheNoMoreConnectionsAllowed { + self.requestQueue.queue(request) + return .none() + } + + var soonAvailable: UInt16 = 0 + + // check if any other EL has an idle connection + switch self.connections.leaseConnectionOrSoonAvailableConnectionCount() { + case .leasedConnection(let leaseResult): + return .init( + request: .leaseConnection(TinyFastSequence(element: request), leaseResult.connection), + connection: .cancelTimers(.init(leaseResult.timersToCancel)) + ) + + case .startingCount(let count): + soonAvailable += count + } + + // we tried everything. there is no connection available. now we must check, if and where we + // can create further connections. but first we must enqueue the new request + + self.requestQueue.queue(request) + + let requestAction = RequestAction.none + + if soonAvailable >= self.requestQueue.count { + // if more connections will be soon available then we have waiters, we don't need to + // create further new connections. + return .init( + request: requestAction, + connection: .none + ) + } else if let request = self.connections.createNewDemandConnectionIfPossible() { + // Can we create a demand connection + return .init( + request: requestAction, + connection: .makeConnection(request, .init()) + ) + } else if let request = self.connections.createNewOverflowConnectionIfPossible() { + // Can we create an overflow connection + return .init( + request: requestAction, + connection: .makeConnection(request, .init()) + ) + } else { + self.cacheNoMoreConnectionsAllowed = true + + // no new connections allowed: + return .init(request: requestAction, connection: .none) + } + } + + @inlinable + mutating func releaseConnection(_ connection: Connection, streams: UInt16) -> Action { + let (index, context) = self.connections.releaseConnection(connection.id, streams: streams) + return self.handleAvailableConnection(index: index, availableContext: context) + } + + mutating func cancelRequest(id: RequestID) -> Action { + guard let request = self.requestQueue.remove(id) else { + return .none() + } + + return .init( + request: .failRequest(request, ConnectionPoolError.requestCancelled), + connection: .none + ) + } + + @inlinable + mutating func connectionEstablished(_ connection: Connection, maxStreams: UInt16) -> Action { + let (index, context) = self.connections.newConnectionEstablished(connection, maxStreams: maxStreams) + return self.handleAvailableConnection(index: index, availableContext: context) + } + + @inlinable + mutating func timerScheduled(_ timer: Timer, cancelContinuation: TimerCancellationToken) -> TimerCancellationToken? { + self.connections.timerScheduled(timer.underlying, cancelContinuation: cancelContinuation) + } + + @inlinable + mutating func timerTriggered(_ timer: Timer) -> Action { + switch timer.underlying.usecase { + case .backoff: + return self.connectionCreationBackoffDone(timer.connectionID) + case .keepAlive: + return self.connectionKeepAliveTimerTriggered(timer.connectionID) + case .idleTimeout: + return self.connectionIdleTimerTriggered(timer.connectionID) + } + } + + @inlinable + mutating func connectionEstablishFailed(_ error: Error, for request: ConnectionRequest) -> Action { + self.failedConsecutiveConnectionAttempts += 1 + + let connectionTimer = self.connections.backoffNextConnectionAttempt(request.connectionID) + let backoff = Self.calculateBackoff(failedAttempt: self.failedConsecutiveConnectionAttempts) + let timer = Timer(connectionTimer, duration: backoff) + return .init(request: .none, connection: .scheduleTimers(.init(timer))) + } + + @inlinable + mutating func connectionCreationBackoffDone(_ connectionID: ConnectionID) -> Action { + let soonAvailable = self.connections.soonAvailableConnections + let retry = (soonAvailable - 1) < self.requestQueue.count + + switch self.connections.backoffDone(connectionID, retry: retry) { + case .createConnection(let request, let continuation): + let timers: TinyFastSequence + if let continuation { + timers = .init(element: continuation) + } else { + timers = .init() + } + return .init(request: .none, connection: .makeConnection(request, timers)) + + case .cancelTimers(let timers): + return .init(request: .none, connection: .cancelTimers(.init(timers))) + } + } + + @inlinable + mutating func connectionKeepAliveTimerTriggered(_ connectionID: ConnectionID) -> Action { + precondition(self.configuration.keepAliveDuration != nil) + precondition(self.requestQueue.isEmpty) + + guard let keepAliveAction = self.connections.keepAliveIfIdle(connectionID) else { + return .none() + } + return .init(request: .none, connection: .runKeepAlive(keepAliveAction.connection, keepAliveAction.keepAliveTimerCancellationContinuation)) + } + + @inlinable + mutating func connectionKeepAliveDone(_ connection: Connection) -> Action { + precondition(self.configuration.keepAliveDuration != nil) + guard let (index, context) = self.connections.keepAliveSucceeded(connection.id) else { + return .none() + } + return self.handleAvailableConnection(index: index, availableContext: context) + } + + @inlinable + mutating func connectionIdleTimerTriggered(_ connectionID: ConnectionID) -> Action { + precondition(self.requestQueue.isEmpty) + + guard let closeAction = self.connections.closeConnectionIfIdle(connectionID) else { + return .none() + } + + self.cacheNoMoreConnectionsAllowed = false + return .init(request: .none, connection: .closeConnection(closeAction.connection, closeAction.timersToCancel)) + } + + @inlinable + mutating func connectionClosed(_ connection: Connection) -> Action { + self.cacheNoMoreConnectionsAllowed = false + + let closedConnectionAction = self.connections.connectionClosed(connection.id) + + let connectionAction: ConnectionAction + if let newRequest = closedConnectionAction.newConnectionRequest { + connectionAction = .makeConnection(newRequest, closedConnectionAction.timersToCancel) + } else { + connectionAction = .cancelTimers(closedConnectionAction.timersToCancel) + } + + return .init(request: .none, connection: connectionAction) + } + + struct CleanupAction { + struct ConnectionToDrop { + var connection: Connection + var keepAliveTimer: Bool + var idleTimer: Bool + } + + var connections: [ConnectionToDrop] + var requests: [Request] + } + + mutating func triggerGracefulShutdown() -> Action { + fatalError("Unimplemented") + } + + mutating func triggerForceShutdown() -> Action { + switch self.poolState { + case .running: + self.poolState = .shuttingDown(graceful: false) + var shutdown = ConnectionAction.Shutdown() + self.connections.triggerForceShutdown(&shutdown) + + if shutdown.connections.isEmpty { + self.poolState = .shutDown + } + + return .init( + request: .failRequests(self.requestQueue.removeAll(), ConnectionPoolError.poolShutdown), + connection: .shutdown(shutdown) + ) + + case .shuttingDown: + return .none() + + case .shutDown: + return .init(request: .none, connection: .none) + } + } + + @inlinable + /*private*/ mutating func handleAvailableConnection( + index: Int, + availableContext: ConnectionGroup.AvailableConnectionContext + ) -> Action { + // this connection was busy before + let requests = self.requestQueue.pop(max: availableContext.info.availableStreams) + if !requests.isEmpty { + let leaseResult = self.connections.leaseConnection(at: index, streams: UInt16(requests.count)) + return .init( + request: .leaseConnection(requests, leaseResult.connection), + connection: .cancelTimers(.init(leaseResult.timersToCancel)) + ) + } + + switch availableContext.use { + case .persisted, .demand: + switch availableContext.info { + case .leased: + return .none() + + case .idle: + let timers = self.connections.parkConnection(at: index).map(self.mapTimers) + + return .init( + request: .none, + connection: .scheduleTimers(timers) + ) + } + + case .overflow: + let closeAction = self.connections.closeConnectionIfIdle(at: index) + return .init( + request: .none, + connection: .closeConnection(closeAction.connection, closeAction.timersToCancel) + ) + } + + } + + @inlinable + /* private */ func mapTimers(_ connectionTimer: ConnectionTimer) -> Timer { + switch connectionTimer.usecase { + case .backoff: + return Timer( + connectionTimer, + duration: Self.calculateBackoff(failedAttempt: self.failedConsecutiveConnectionAttempts) + ) + + case .keepAlive: + return Timer(connectionTimer, duration: self.configuration.keepAliveDuration!) + + case .idleTimeout: + return Timer(connectionTimer, duration: self.configuration.idleTimeoutDuration) + + } + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine { + /// Calculates the delay for the next connection attempt after the given number of failed `attempts`. + /// + /// Our backoff formula is: 100ms * 1.25^(attempts - 1) with 3% jitter that is capped of at 1 minute. + /// This means for: + /// - 1 failed attempt : 100ms + /// - 5 failed attempts: ~300ms + /// - 10 failed attempts: ~930ms + /// - 15 failed attempts: ~2.84s + /// - 20 failed attempts: ~8.67s + /// - 25 failed attempts: ~26s + /// - 29 failed attempts: ~60s (max out) + /// + /// - Parameter attempts: number of failed attempts in a row + /// - Returns: time to wait until trying to establishing a new connection + @usableFromInline + static func calculateBackoff(failedAttempt attempts: Int) -> Duration { + // Our backoff formula is: 100ms * 1.25^(attempts - 1) that is capped of at 1minute + // This means for: + // - 1 failed attempt : 100ms + // - 5 failed attempts: ~300ms + // - 10 failed attempts: ~930ms + // - 15 failed attempts: ~2.84s + // - 20 failed attempts: ~8.67s + // - 25 failed attempts: ~26s + // - 29 failed attempts: ~60s (max out) + + let start = Double(100_000_000) + let backoffNanosecondsDouble = start * pow(1.25, Double(attempts - 1)) + + // Cap to 60s _before_ we convert to Int64, to avoid trapping in the Int64 initializer. + let backoffNanoseconds = Int64(min(backoffNanosecondsDouble, Double(60_000_000_000))) + + let backoff = Duration.nanoseconds(backoffNanoseconds) + + // Calculate a 3% jitter range + let jitterRange = (backoffNanoseconds / 100) * 3 + // Pick a random element from the range +/- jitter range. + let jitter: Duration = .nanoseconds((-jitterRange...jitterRange).randomElement()!) + let jitteredBackoff = backoff + jitter + return jitteredBackoff + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.Action: Equatable where TimerCancellationToken: Equatable, Request: Equatable {} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.ConnectionAction: Equatable where TimerCancellationToken: Equatable { + @usableFromInline + static func ==(lhs: Self, rhs: Self) -> Bool { + switch (lhs, rhs) { + case (.scheduleTimers(let lhs), .scheduleTimers(let rhs)): + return lhs == rhs + case (.makeConnection(let lhsRequest, let lhsToken), .makeConnection(let rhsRequest, let rhsToken)): + return lhsRequest == rhsRequest && lhsToken == rhsToken + case (.runKeepAlive(let lhsConn, let lhsToken), .runKeepAlive(let rhsConn, let rhsToken)): + return lhsConn === rhsConn && lhsToken == rhsToken + case (.closeConnection(let lhsConn, let lhsTimers), .closeConnection(let rhsConn, let rhsTimers)): + return lhsConn === rhsConn && lhsTimers == rhsTimers + case (.shutdown(let lhs), .shutdown(let rhs)): + return lhs == rhs + case (.cancelTimers(let lhs), .cancelTimers(let rhs)): + return lhs == rhs + case (.none, .none), + (.cancelTimers([]), .none), (.none, .cancelTimers([])), + (.scheduleTimers([]), .none), (.none, .scheduleTimers([])): + return true + default: + return false + } + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.ConnectionAction.Shutdown: Equatable where TimerCancellationToken: Equatable { + @usableFromInline + static func ==(lhs: Self, rhs: Self) -> Bool { + Set(lhs.connections.lazy.map(\.id)) == Set(rhs.connections.lazy.map(\.id)) && lhs.timersToCancel == rhs.timersToCancel + } +} + + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.RequestAction: Equatable where Request: Equatable { + + @usableFromInline + static func ==(lhs: Self, rhs: Self) -> Bool { + switch (lhs, rhs) { + case (.leaseConnection(let lhsRequests, let lhsConn), .leaseConnection(let rhsRequests, let rhsConn)): + guard lhsRequests.count == rhsRequests.count else { return false } + var lhsIterator = lhsRequests.makeIterator() + var rhsIterator = rhsRequests.makeIterator() + while let lhsNext = lhsIterator.next(), let rhsNext = rhsIterator.next() { + guard lhsNext.id == rhsNext.id else { return false } + } + return lhsConn === rhsConn + + case (.failRequest(let lhsRequest, let lhsError), .failRequest(let rhsRequest, let rhsError)): + return lhsRequest.id == rhsRequest.id && lhsError == rhsError + + case (.failRequests(let lhsRequests, let lhsError), .failRequests(let rhsRequests, let rhsError)): + return Set(lhsRequests.lazy.map(\.id)) == Set(rhsRequests.lazy.map(\.id)) && lhsError == rhsError + + case (.none, .none): + return true + + default: + return false + } + } } diff --git a/Sources/ConnectionPoolModule/OneElementFastSequence.swift b/Sources/ConnectionPoolModule/TinyFastSequence.swift similarity index 58% rename from Sources/ConnectionPoolModule/OneElementFastSequence.swift rename to Sources/ConnectionPoolModule/TinyFastSequence.swift index 3c3bfaa0..dff8a30b 100644 --- a/Sources/ConnectionPoolModule/OneElementFastSequence.swift +++ b/Sources/ConnectionPoolModule/TinyFastSequence.swift @@ -1,10 +1,11 @@ /// A `Sequence` that does not heap allocate, if it only carries a single element @usableFromInline -struct OneElementFastSequence: Sequence { +struct TinyFastSequence: Sequence { @usableFromInline enum Base { case none(reserveCapacity: Int) case one(Element, reserveCapacity: Int) + case two(Element, Element, reserveCapacity: Int) case n([Element]) } @@ -37,6 +38,20 @@ struct OneElementFastSequence: Sequence { } } + @inlinable + init(_ max2Sequence: Max2Sequence) { + switch max2Sequence.count { + case 0: + self.base = .none(reserveCapacity: 0) + case 1: + self.base = .one(max2Sequence.first!, reserveCapacity: 0) + case 2: + self.base = .n(Array(max2Sequence)) + default: + fatalError() + } + } + @usableFromInline var count: Int { switch self.base { @@ -44,6 +59,8 @@ struct OneElementFastSequence: Sequence { return 0 case .one: return 1 + case .two: + return 2 case .n(let array): return array.count } @@ -56,6 +73,8 @@ struct OneElementFastSequence: Sequence { return nil case .one(let element, _): return element + case .two(let first, _, _): + return first case .n(let array): return array.first } @@ -66,7 +85,7 @@ struct OneElementFastSequence: Sequence { switch self.base { case .none: return true - case .one, .n: + case .one, .two, .n: return false } } @@ -78,6 +97,8 @@ struct OneElementFastSequence: Sequence { self.base = .none(reserveCapacity: Swift.max(reservedCapacity, minimumCapacity)) case .one(let element, let reservedCapacity): self.base = .one(element, reserveCapacity: Swift.max(reservedCapacity, minimumCapacity)) + case .two(let first, let second, let reservedCapacity): + self.base = .two(first, second, reserveCapacity: Swift.max(reservedCapacity, minimumCapacity)) case .n(var array): self.base = .none(reserveCapacity: 0) // prevent CoW array.reserveCapacity(minimumCapacity) @@ -90,12 +111,17 @@ struct OneElementFastSequence: Sequence { switch self.base { case .none(let reserveCapacity): self.base = .one(element, reserveCapacity: reserveCapacity) - case .one(let existing, let reserveCapacity): + case .one(let first, let reserveCapacity): + self.base = .two(first, element, reserveCapacity: reserveCapacity) + + case .two(let first, let second, let reserveCapacity): var new = [Element]() - new.reserveCapacity(reserveCapacity) - new.append(existing) + new.reserveCapacity(Swift.max(4, reserveCapacity)) + new.append(first) + new.append(second) new.append(element) self.base = .n(new) + case .n(var existing): self.base = .none(reserveCapacity: 0) // prevent CoW existing.append(element) @@ -111,10 +137,10 @@ struct OneElementFastSequence: Sequence { @usableFromInline struct Iterator: IteratorProtocol { @usableFromInline private(set) var index: Int = 0 - @usableFromInline private(set) var backing: OneElementFastSequence + @usableFromInline private(set) var backing: TinyFastSequence @inlinable - init(_ backing: OneElementFastSequence) { + init(_ backing: TinyFastSequence) { self.backing = backing } @@ -130,6 +156,17 @@ struct OneElementFastSequence: Sequence { } return nil + case .two(let first, let second, _): + defer { self.index += 1 } + switch self.index { + case 0: + return first + case 1: + return second + default: + return nil + } + case .n(let array): if self.index < array.endIndex { defer { self.index += 1} @@ -141,11 +178,28 @@ struct OneElementFastSequence: Sequence { } } -extension OneElementFastSequence: Equatable where Element: Equatable {} -extension OneElementFastSequence.Base: Equatable where Element: Equatable {} +extension TinyFastSequence: Equatable where Element: Equatable {} +extension TinyFastSequence.Base: Equatable where Element: Equatable {} + +extension TinyFastSequence: Hashable where Element: Hashable {} +extension TinyFastSequence.Base: Hashable where Element: Hashable {} -extension OneElementFastSequence: Hashable where Element: Hashable {} -extension OneElementFastSequence.Base: Hashable where Element: Hashable {} +extension TinyFastSequence: Sendable where Element: Sendable {} +extension TinyFastSequence.Base: Sendable where Element: Sendable {} -extension OneElementFastSequence: Sendable where Element: Sendable {} -extension OneElementFastSequence.Base: Sendable where Element: Sendable {} +extension TinyFastSequence: ExpressibleByArrayLiteral { + @inlinable + init(arrayLiteral elements: Element...) { + var iterator = elements.makeIterator() + switch elements.count { + case 0: + self.base = .none(reserveCapacity: 0) + case 1: + self.base = .one(iterator.next()!, reserveCapacity: 0) + case 2: + self.base = .two(iterator.next()!, iterator.next()!, reserveCapacity: 0) + default: + self.base = .n(elements) + } + } +} diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift index 4e3a1647..bf385918 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift @@ -120,7 +120,7 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { XCTAssertEqual(connections.stats, .init(closing: 1, availableStreams: 0)) let closeContext = connections.connectionClosed(newConnection.id) - XCTAssertEqual(closeContext?.connectionsStarting, 0) + XCTAssertEqual(closeContext.connectionsStarting, 0) XCTAssertTrue(connections.isEmpty) XCTAssertEqual(connections.stats, .init()) } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift index ee8cfdc6..0f3af728 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -1,5 +1,3 @@ -import NIOCore -import NIOEmbedded import XCTest @testable import _ConnectionPoolModule @@ -12,3 +10,218 @@ typealias TestPoolStateMachine = PoolStateMachine< MockRequest.ID, MockTimerCancellationToken > + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class PoolStateMachineTests: XCTestCase { + + func testConnectionsAreCreatedAndParkedOnStartup() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 2 + configuration.maximumConnectionSoftLimit = 4 + configuration.maximumConnectionHardLimit = 6 + configuration.keepAliveDuration = .seconds(10) + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + let connection1 = MockConnection(id: 0) + let connection2 = MockConnection(id: 1) + + do { + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 2) + let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) + let connection1KeepAliveTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 0, usecase: .keepAlive), duration: .seconds(10)) + let connection1KeepAliveTimerCancellationToken = MockTimerCancellationToken(connection1KeepAliveTimer) + XCTAssertEqual(createdAction1.request, .none) + XCTAssertEqual(createdAction1.connection, .scheduleTimers([connection1KeepAliveTimer])) + + XCTAssertEqual(stateMachine.timerScheduled(connection1KeepAliveTimer, cancelContinuation: connection1KeepAliveTimerCancellationToken), .none) + + let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) + let connection2KeepAliveTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 1, usecase: .keepAlive), duration: .seconds(10)) + let connection2KeepAliveTimerCancellationToken = MockTimerCancellationToken(connection2KeepAliveTimer) + XCTAssertEqual(createdAction2.request, .none) + XCTAssertEqual(createdAction2.connection, .scheduleTimers([connection2KeepAliveTimer])) + XCTAssertEqual(stateMachine.timerScheduled(connection2KeepAliveTimer, cancelContinuation: connection2KeepAliveTimerCancellationToken), .none) + } + } + + func testConnectionsNoKeepAliveRun() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 1 + configuration.maximumConnectionSoftLimit = 4 + configuration.maximumConnectionHardLimit = 6 + configuration.keepAliveDuration = nil + configuration.idleTimeoutDuration = .seconds(5) + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + let connection1 = MockConnection(id: 0) + + // refill pool to at least one connection + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 1) + let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) + XCTAssertEqual(createdAction1.request, .none) + XCTAssertEqual(createdAction1.connection, .scheduleTimers([])) + + // lease connection 1 + let request1 = MockRequest() + let leaseRequest1 = stateMachine.leaseConnection(request1) + XCTAssertEqual(leaseRequest1.connection, .cancelTimers([])) + XCTAssertEqual(leaseRequest1.request, .leaseConnection(.init(element: request1), connection1)) + + // release connection 1 + XCTAssertEqual(stateMachine.releaseConnection(connection1, streams: 1), .none()) + + // lease connection 1 + let request2 = MockRequest() + let leaseRequest2 = stateMachine.leaseConnection(request2) + XCTAssertEqual(leaseRequest2.connection, .cancelTimers([])) + XCTAssertEqual(leaseRequest2.request, .leaseConnection(.init(element: request2), connection1)) + + // request connection while none is available + let request3 = MockRequest() + let leaseRequest3 = stateMachine.leaseConnection(request3) + XCTAssertEqual(leaseRequest3.connection, .makeConnection(.init(connectionID: 1), [])) + XCTAssertEqual(leaseRequest3.request, .none) + + // make connection 2 and lease immediately + let connection2 = MockConnection(id: 1) + let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) + XCTAssertEqual(createdAction2.request, .leaseConnection(.init(element: request3), connection2)) + XCTAssertEqual(createdAction2.connection, .none) + + // release connection 2 + let connection2IdleTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 1, usecase: .idleTimeout), duration: configuration.idleTimeoutDuration) + let connection2IdleTimerCancellationToken = MockTimerCancellationToken(connection2IdleTimer) + XCTAssertEqual( + stateMachine.releaseConnection(connection2, streams: 1), + .init(request: .none, connection: .scheduleTimers([connection2IdleTimer])) + ) + + XCTAssertEqual(stateMachine.timerScheduled(connection2IdleTimer, cancelContinuation: connection2IdleTimerCancellationToken), .none) + XCTAssertEqual(stateMachine.timerTriggered(connection2IdleTimer), .init(request: .none, connection: .closeConnection(connection2, [connection2IdleTimerCancellationToken]))) + } + + func testOnlyOverflowConnections() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 0 + configuration.maximumConnectionSoftLimit = 0 + configuration.maximumConnectionHardLimit = 6 + configuration.keepAliveDuration = nil + configuration.idleTimeoutDuration = .seconds(3) + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + // don't refill pool + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 0) + + // request connection while none exists + let request1 = MockRequest() + let leaseRequest1 = stateMachine.leaseConnection(request1) + XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) + XCTAssertEqual(leaseRequest1.request, .none) + + // make connection 1 and lease immediately + let connection1 = MockConnection(id: 0) + let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) + XCTAssertEqual(createdAction1.request, .leaseConnection(.init(element: request1), connection1)) + XCTAssertEqual(createdAction1.connection, .none) + + // request connection while none is available + let request2 = MockRequest() + let leaseRequest2 = stateMachine.leaseConnection(request2) + XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) + XCTAssertEqual(leaseRequest2.request, .none) + + // release connection 1 should be leased again immediately + let releaseRequest1 = stateMachine.releaseConnection(connection1, streams: 1) + XCTAssertEqual(releaseRequest1.request, .leaseConnection(.init(element: request2), connection1)) + XCTAssertEqual(releaseRequest1.connection, .none) + + // connection 2 comes up and should be closed right away + let connection2 = MockConnection(id: 1) + let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) + XCTAssertEqual(createdAction2.request, .none) + XCTAssertEqual(createdAction2.connection, .closeConnection(connection2, [])) + XCTAssertEqual(stateMachine.connectionClosed(connection2), .none()) + + // release connection 1 should be closed as well + let releaseRequest2 = stateMachine.releaseConnection(connection1, streams: 1) + XCTAssertEqual(releaseRequest2.request, .none) + XCTAssertEqual(releaseRequest2.connection, .closeConnection(connection1, [])) + + let shutdownAction = stateMachine.triggerForceShutdown() + XCTAssertEqual(shutdownAction.request, .failRequests(.init(), .poolShutdown)) + XCTAssertEqual(shutdownAction.connection, .shutdown(.init())) + } + + func testDemandConnectionIsMadePermanentIfPermanentIsClose() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 1 + configuration.maximumConnectionSoftLimit = 2 + configuration.maximumConnectionHardLimit = 6 + configuration.keepAliveDuration = nil + configuration.idleTimeoutDuration = .seconds(3) + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + let connection1 = MockConnection(id: 0) + + // refill pool to at least one connection + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 1) + let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) + XCTAssertEqual(createdAction1.request, .none) + XCTAssertEqual(createdAction1.connection, .scheduleTimers([])) + + // lease connection 1 + let request1 = MockRequest() + let leaseRequest1 = stateMachine.leaseConnection(request1) + XCTAssertEqual(leaseRequest1.connection, .cancelTimers([])) + XCTAssertEqual(leaseRequest1.request, .leaseConnection(.init(element: request1), connection1)) + + // request connection while none is available + let request2 = MockRequest() + let leaseRequest2 = stateMachine.leaseConnection(request2) + XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) + XCTAssertEqual(leaseRequest2.request, .none) + + // make connection 2 and lease immediately + let connection2 = MockConnection(id: 1) + let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) + XCTAssertEqual(createdAction2.request, .leaseConnection(.init(element: request2), connection2)) + XCTAssertEqual(createdAction2.connection, .none) + + // release connection 2 + let connection2IdleTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 1, usecase: .idleTimeout), duration: configuration.idleTimeoutDuration) + let connection2IdleTimerCancellationToken = MockTimerCancellationToken(connection2IdleTimer) + XCTAssertEqual( + stateMachine.releaseConnection(connection2, streams: 1), + .init(request: .none, connection: .scheduleTimers([connection2IdleTimer])) + ) + + XCTAssertEqual(stateMachine.timerScheduled(connection2IdleTimer, cancelContinuation: connection2IdleTimerCancellationToken), .none) + + // connection 1 is dropped + XCTAssertEqual(stateMachine.connectionClosed(connection1), .init(request: .none, connection: .cancelTimers([connection2IdleTimerCancellationToken]))) + } +} diff --git a/Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift b/Tests/ConnectionPoolModuleTests/TinyFastSequence.swift similarity index 82% rename from Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift rename to Tests/ConnectionPoolModuleTests/TinyFastSequence.swift index a086341e..b3f8179d 100644 --- a/Tests/ConnectionPoolModuleTests/OneElementFastSequence.swift +++ b/Tests/ConnectionPoolModuleTests/TinyFastSequence.swift @@ -3,7 +3,7 @@ import XCTest final class OneElementFastSequenceTests: XCTestCase { func testCountIsEmptyAndIterator() async { - var sequence = OneElementFastSequence() + var sequence = TinyFastSequence() XCTAssertEqual(sequence.count, 0) XCTAssertEqual(sequence.isEmpty, true) XCTAssertEqual(sequence.first, nil) @@ -26,24 +26,26 @@ final class OneElementFastSequenceTests: XCTestCase { } func testReserveCapacityIsForwarded() { - var emptySequence = OneElementFastSequence() + var emptySequence = TinyFastSequence() emptySequence.reserveCapacity(8) emptySequence.append(1) emptySequence.append(2) + emptySequence.append(3) guard case .n(let array) = emptySequence.base else { return XCTFail("Expected sequence to be backed by an array") } XCTAssertEqual(array.capacity, 8) - var oneElemSequence = OneElementFastSequence(element: 1) + var oneElemSequence = TinyFastSequence(element: 1) oneElemSequence.reserveCapacity(8) oneElemSequence.append(2) + oneElemSequence.append(3) guard case .n(let array) = oneElemSequence.base else { return XCTFail("Expected sequence to be backed by an array") } XCTAssertEqual(array.capacity, 8) - var twoElemSequence = OneElementFastSequence([1, 2]) + var twoElemSequence = TinyFastSequence([1, 2]) twoElemSequence.reserveCapacity(8) guard case .n(let array) = twoElemSequence.base else { return XCTFail("Expected sequence to be backed by an array") @@ -52,17 +54,17 @@ final class OneElementFastSequenceTests: XCTestCase { } func testNewSequenceSlowPath() { - let sequence = OneElementFastSequence("AB".utf8) + let sequence = TinyFastSequence("AB".utf8) XCTAssertEqual(Array(sequence), [UInt8(ascii: "A"), UInt8(ascii: "B")]) } func testSingleItem() { - let sequence = OneElementFastSequence("A".utf8) + let sequence = TinyFastSequence("A".utf8) XCTAssertEqual(Array(sequence), [UInt8(ascii: "A")]) } func testEmptyCollection() { - let sequence = OneElementFastSequence("".utf8) + let sequence = TinyFastSequence("".utf8) XCTAssertTrue(sequence.isEmpty) XCTAssertEqual(sequence.count, 0) XCTAssertEqual(Array(sequence), []) From 468ae25f310e877b6613058e8ad2750cfe11f5d8 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 27 Oct 2023 08:16:30 +0200 Subject: [PATCH 204/292] Land ConnectionPool (#428) --- .../ConnectionPoolModule/ConnectionPool.swift | 484 +++++++++++++++++- .../ConnectionRequest.swift | 53 ++ .../NIOLockedValueBox.swift | 46 ++ .../PoolStateMachine+ConnectionGroup.swift | 7 +- .../PoolStateMachine.swift | 61 ++- .../ConnectionPoolTests.swift | 189 +++++++ .../Mocks/MockClock.swift | 186 +++++++ .../Mocks/MockConnection.swift | 73 +++ .../Mocks/MockPingPongBehaviour.swift | 14 + ...oolStateMachine+ConnectionGroupTests.swift | 4 +- .../PoolStateMachineTests.swift | 42 ++ ...ence.swift => TinyFastSequenceTests.swift} | 2 +- 12 files changed, 1135 insertions(+), 26 deletions(-) create mode 100644 Sources/ConnectionPoolModule/NIOLockedValueBox.swift create mode 100644 Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift create mode 100644 Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift create mode 100644 Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift rename Tests/ConnectionPoolModuleTests/{TinyFastSequence.swift => TinyFastSequenceTests.swift} (97%) diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index 825c3ab3..5571e617 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -1,3 +1,17 @@ + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public struct ConnectionAndMetadata { + + public var connection: Connection + + public var maximalStreamsOnConnection: UInt16 + + public init(connection: Connection, maximalStreamsOnConnection: UInt16) { + self.connection = connection + self.maximalStreamsOnConnection = maximalStreamsOnConnection + } +} + /// A connection that can be pooled in a ``ConnectionPool`` public protocol PooledConnection: AnyObject, Sendable { /// The connections identifier type. @@ -78,7 +92,7 @@ public protocol ConnectionRequestProtocol: Sendable { } @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -public struct ConnectionPoolConfiguration { +public struct ConnectionPoolConfiguration: Sendable { /// The minimum number of connections to preserve in the pool. /// /// If the pool is mostly idle and the remote servers closes @@ -114,3 +128,471 @@ public struct ConnectionPoolConfiguration { self.idleTimeout = .seconds(60) } } + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +public final class ConnectionPool< + Connection: PooledConnection, + ConnectionID: Hashable & Sendable, + ConnectionIDGenerator: ConnectionIDGeneratorProtocol, + Request: ConnectionRequestProtocol, + RequestID: Hashable & Sendable, + KeepAliveBehavior: ConnectionKeepAliveBehavior, + ObservabilityDelegate: ConnectionPoolObservabilityDelegate, + Clock: _Concurrency.Clock +>: Sendable where + Connection.ID == ConnectionID, + ConnectionIDGenerator.ID == ConnectionID, + Request.Connection == Connection, + Request.ID == RequestID, + KeepAliveBehavior.Connection == Connection, + ObservabilityDelegate.ConnectionID == ConnectionID, + Clock.Duration == Duration +{ + public typealias ConnectionFactory = @Sendable (ConnectionID, ConnectionPool) async throws -> ConnectionAndMetadata + + @usableFromInline + typealias StateMachine = PoolStateMachine> + + @usableFromInline + let factory: ConnectionFactory + + @usableFromInline + let keepAliveBehavior: KeepAliveBehavior + + @usableFromInline + let observabilityDelegate: ObservabilityDelegate + + @usableFromInline + let clock: Clock + + @usableFromInline + let configuration: ConnectionPoolConfiguration + + @usableFromInline + struct State: Sendable { + @usableFromInline + var stateMachine: StateMachine + @usableFromInline + var lastConnectError: (any Error)? + } + + @usableFromInline let stateBox: NIOLockedValueBox + + private let requestIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator() + + @usableFromInline + let eventStream: AsyncStream + + @usableFromInline + let eventContinuation: AsyncStream.Continuation + + public init( + configuration: ConnectionPoolConfiguration, + idGenerator: ConnectionIDGenerator, + requestType: Request.Type, + keepAliveBehavior: KeepAliveBehavior, + observabilityDelegate: ObservabilityDelegate, + clock: Clock, + connectionFactory: @escaping ConnectionFactory + ) { + self.clock = clock + self.factory = connectionFactory + self.keepAliveBehavior = keepAliveBehavior + self.observabilityDelegate = observabilityDelegate + self.configuration = configuration + var stateMachine = StateMachine( + configuration: .init(configuration, keepAliveBehavior: keepAliveBehavior), + generator: idGenerator, + timerCancellationTokenType: CheckedContinuation.self + ) + + let (stream, continuation) = AsyncStream.makeStream(of: NewPoolActions.self) + self.eventStream = stream + self.eventContinuation = continuation + + let connectionRequests = stateMachine.refillConnections() + + self.stateBox = NIOLockedValueBox(.init(stateMachine: stateMachine)) + + for request in connectionRequests { + self.eventContinuation.yield(.makeConnection(request)) + } + } + + @inlinable + public func releaseConnection(_ connection: Connection, streams: UInt16 = 1) { + self.modifyStateAndRunActions { state in + state.stateMachine.releaseConnection(connection, streams: streams) + } + } + + @inlinable + public func leaseConnection(_ request: Request) { + self.modifyStateAndRunActions { state in + state.stateMachine.leaseConnection(request) + } + } + + @inlinable + public func leaseConnections(_ requests: some Collection) { + let actions = self.stateBox.withLockedValue { state in + var actions = [StateMachine.Action]() + actions.reserveCapacity(requests.count) + + for request in requests { + let stateMachineAction = state.stateMachine.leaseConnection(request) + actions.append(stateMachineAction) + } + + return actions + } + + for action in actions { + self.runRequestAction(action.request) + self.runConnectionAction(action.connection) + } + } + + public func cancelLeaseConnection(_ requestID: RequestID) { + self.modifyStateAndRunActions { state in + state.stateMachine.cancelRequest(id: requestID) + } + } + + /// Mark a connection as going away. Connection implementors have to call this method if the connection + /// has received a close intent from the server. For example: an HTTP/2 GOWAY frame. + public func connectionWillClose(_ connection: Connection) { + + } + + public func connection(_ connection: Connection, didReceiveNewMaxStreamSetting: UInt16) { + + } + + public func run() async { + await withTaskCancellationHandler { + #if swift(>=5.8) && os(Linux) || swift(>=5.9) + if #available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) { + return await withDiscardingTaskGroup() { taskGroup in + await self.run(in: &taskGroup) + } + } + #endif + return await withTaskGroup(of: Void.self) { taskGroup in + await self.run(in: &taskGroup) + } + } onCancel: { + let actions = self.stateBox.withLockedValue { state in + state.stateMachine.triggerForceShutdown() + } + + self.runStateMachineActions(actions) + } + } + + // MARK: - Private Methods - + + @inlinable + func connectionDidClose(_ connection: Connection, error: (any Error)?) { + self.observabilityDelegate.connectionClosed(id: connection.id, error: error) + + self.modifyStateAndRunActions { state in + state.stateMachine.connectionClosed(connection) + } + } + + // MARK: Events + + @usableFromInline + enum NewPoolActions: Sendable { + case makeConnection(StateMachine.ConnectionRequest) + case closeConnection(Connection) + case runKeepAlive(Connection) + + case scheduleTimer(StateMachine.Timer) + } + + #if swift(>=5.8) && os(Linux) || swift(>=5.9) + @available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) + private func run(in taskGroup: inout DiscardingTaskGroup) async { + for await event in self.eventStream { + self.runEvent(event, in: &taskGroup) + } + } + #endif + + private func run(in taskGroup: inout TaskGroup) async { + var running = 0 + for await event in self.eventStream { + running += 1 + self.runEvent(event, in: &taskGroup) + + if running == 100 { + _ = await taskGroup.next() + running -= 1 + } + } + } + + private func runEvent(_ event: NewPoolActions, in taskGroup: inout some TaskGroupProtocol) { + switch event { + case .makeConnection(let request): + self.makeConnection(for: request, in: &taskGroup) + + case .runKeepAlive(let connection): + self.runKeepAlive(connection, in: &taskGroup) + + case .closeConnection(let connection): + self.closeConnection(connection) + + case .scheduleTimer(let timer): + self.runTimer(timer, in: &taskGroup) + } + } + + // MARK: Run actions + + @inlinable + /*private*/ func modifyStateAndRunActions(_ closure: (inout State) -> StateMachine.Action) { + let actions = self.stateBox.withLockedValue { state -> StateMachine.Action in + closure(&state) + } + self.runStateMachineActions(actions) + } + + @inlinable + /*private*/ func runStateMachineActions(_ actions: StateMachine.Action) { + self.runConnectionAction(actions.connection) + self.runRequestAction(actions.request) + } + + @inlinable + /*private*/ func runConnectionAction(_ action: StateMachine.ConnectionAction) { + switch action { + case .makeConnection(let request, let timers): + self.cancelTimers(timers) + self.eventContinuation.yield(.makeConnection(request)) + + case .runKeepAlive(let connection, let cancelContinuation): + cancelContinuation?.resume(returning: ()) + self.eventContinuation.yield(.runKeepAlive(connection)) + + case .scheduleTimers(let timers): + for timer in timers { + self.eventContinuation.yield(.scheduleTimer(timer)) + } + + case .cancelTimers(let timers): + self.cancelTimers(timers) + + case .closeConnection(let connection, let timers): + self.closeConnection(connection) + self.cancelTimers(timers) + + case .shutdown(let cleanup): + for connection in cleanup.connections { + self.closeConnection(connection) + } + self.cancelTimers(cleanup.timersToCancel) + + case .none: + break + } + } + + @inlinable + /*private*/ func runRequestAction(_ action: StateMachine.RequestAction) { + switch action { + case .leaseConnection(let requests, let connection): + for request in requests { + request.complete(with: .success(connection)) + } + + case .failRequest(let request, let error): + request.complete(with: .failure(error)) + + case .failRequests(let requests, let error): + for request in requests { request.complete(with: .failure(error)) } + + case .none: + break + } + } + + @inlinable + /*private*/ func makeConnection(for request: StateMachine.ConnectionRequest, in taskGroup: inout some TaskGroupProtocol) { + taskGroup.addTask { + self.observabilityDelegate.startedConnecting(id: request.connectionID) + + do { + let bundle = try await self.factory(request.connectionID, self) + self.connectionEstablished(bundle) + bundle.connection.onClose { + self.connectionDidClose(bundle.connection, error: $0) + } + } catch { + self.connectionEstablishFailed(error, for: request) + } + } + } + + @inlinable + /*private*/ func connectionEstablished(_ connectionBundle: ConnectionAndMetadata) { + self.observabilityDelegate.connectSucceeded(id: connectionBundle.connection.id, streamCapacity: connectionBundle.maximalStreamsOnConnection) + + self.modifyStateAndRunActions { state in + state.lastConnectError = nil + return state.stateMachine.connectionEstablished( + connectionBundle.connection, + maxStreams: connectionBundle.maximalStreamsOnConnection + ) + } + } + + @inlinable + /*private*/ func connectionEstablishFailed(_ error: Error, for request: StateMachine.ConnectionRequest) { + self.observabilityDelegate.connectFailed(id: request.connectionID, error: error) + + self.modifyStateAndRunActions { state in + state.lastConnectError = error + return state.stateMachine.connectionEstablishFailed(error, for: request) + } + } + + @inlinable + /*private*/ func runKeepAlive(_ connection: Connection, in taskGroup: inout some TaskGroupProtocol) { + self.observabilityDelegate.keepAliveTriggered(id: connection.id) + + taskGroup.addTask { + do { + try await self.keepAliveBehavior.runKeepAlive(for: connection) + + self.observabilityDelegate.keepAliveSucceeded(id: connection.id) + + self.modifyStateAndRunActions { state in + state.stateMachine.connectionKeepAliveDone(connection) + } + } catch { + self.observabilityDelegate.keepAliveFailed(id: connection.id, error: error) + + self.modifyStateAndRunActions { state in + state.stateMachine.connectionClosed(connection) + } + } + } + } + + @inlinable + /*private*/ func closeConnection(_ connection: Connection) { + self.observabilityDelegate.connectionClosing(id: connection.id) + + connection.close() + } + + @usableFromInline + enum TimerRunResult { + case timerTriggered + case timerCancelled + case cancellationContinuationFinished + } + + @inlinable + /*private*/ func runTimer(_ timer: StateMachine.Timer, in poolGroup: inout some TaskGroupProtocol) { + poolGroup.addTask { () async -> () in + await withTaskGroup(of: TimerRunResult.self, returning: Void.self) { taskGroup in + taskGroup.addTask { + do { + #if swift(>=5.8) && os(Linux) || swift(>=5.9) + try await self.clock.sleep(for: timer.duration) + #else + try await self.clock.sleep(until: self.clock.now.advanced(by: timer.duration), tolerance: nil) + #endif + return .timerTriggered + } catch { + return .timerCancelled + } + } + + taskGroup.addTask { + await withCheckedContinuation { (continuation: CheckedContinuation) in + let continuation = self.stateBox.withLockedValue { state in + state.stateMachine.timerScheduled(timer, cancelContinuation: continuation) + } + + continuation?.resume(returning: ()) + } + + return .cancellationContinuationFinished + } + + switch await taskGroup.next()! { + case .cancellationContinuationFinished: + taskGroup.cancelAll() + + case .timerTriggered: + let action = self.stateBox.withLockedValue { state in + state.stateMachine.timerTriggered(timer) + } + + self.runStateMachineActions(action) + + case .timerCancelled: + // the only way to reach this, is if the state machine decided to cancel the + // timer. therefore we don't need to report it back! + break + } + + return + } + } + } + + @inlinable + /*private*/ func cancelTimers(_ cancellationTokens: some Sequence>) { + for token in cancellationTokens { + token.resume() + } + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolConfiguration { + init(_ configuration: ConnectionPoolConfiguration, keepAliveBehavior: KeepAliveBehavior) { + self.minimumConnectionCount = configuration.minimumConnectionCount + self.maximumConnectionSoftLimit = configuration.maximumConnectionSoftLimit + self.maximumConnectionHardLimit = configuration.maximumConnectionHardLimit + self.keepAliveDuration = keepAliveBehavior.keepAliveFrequency + self.idleTimeoutDuration = configuration.idleTimeout + } +} + +#if swift(<5.9) +// This should be removed once we support Swift 5.9+ only +extension AsyncStream { + static func makeStream( + of elementType: Element.Type = Element.self, + bufferingPolicy limit: Continuation.BufferingPolicy = .unbounded + ) -> (stream: AsyncStream, continuation: AsyncStream.Continuation) { + var continuation: AsyncStream.Continuation! + let stream = AsyncStream(bufferingPolicy: limit) { continuation = $0 } + return (stream: stream, continuation: continuation!) + } +} +#endif + +@usableFromInline +protocol TaskGroupProtocol { + mutating func addTask(operation: @escaping @Sendable () async -> Void) +} + +#if swift(>=5.8) && os(Linux) || swift(>=5.9) +@available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 9.0, *) +extension DiscardingTaskGroup: TaskGroupProtocol {} +#endif + +extension TaskGroup: TaskGroupProtocol { + @inlinable + mutating func addTask(operation: @escaping @Sendable () async -> Void) { + self.addTask(priority: nil, operation: operation) + } +} diff --git a/Sources/ConnectionPoolModule/ConnectionRequest.swift b/Sources/ConnectionPoolModule/ConnectionRequest.swift index fd01bb76..19ed9bd2 100644 --- a/Sources/ConnectionPoolModule/ConnectionRequest.swift +++ b/Sources/ConnectionPoolModule/ConnectionRequest.swift @@ -20,3 +20,56 @@ public struct ConnectionRequest: ConnectionRequest self.continuation.resume(with: result) } } + +fileprivate let requestIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator() + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension ConnectionPool where Request == ConnectionRequest { + public convenience init( + configuration: ConnectionPoolConfiguration, + idGenerator: ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator(), + keepAliveBehavior: KeepAliveBehavior, + observabilityDelegate: ObservabilityDelegate, + clock: Clock = ContinuousClock(), + connectionFactory: @escaping ConnectionFactory + ) { + self.init( + configuration: configuration, + idGenerator: idGenerator, + requestType: ConnectionRequest.self, + keepAliveBehavior: keepAliveBehavior, + observabilityDelegate: observabilityDelegate, + clock: clock, + connectionFactory: connectionFactory + ) + } + + public func leaseConnection() async throws -> Connection { + let requestID = requestIDGenerator.next() + + let connection = try await withTaskCancellationHandler { + if Task.isCancelled { + throw CancellationError() + } + + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let request = Request( + id: requestID, + continuation: continuation + ) + + self.leaseConnection(request) + } + } onCancel: { + self.cancelLeaseConnection(requestID) + } + + return connection + } + + public func withConnection(_ closure: (Connection) async throws -> Result) async throws -> Result { + let connection = try await self.leaseConnection() + defer { self.releaseConnection(connection) } + return try await closure(connection) + } +} diff --git a/Sources/ConnectionPoolModule/NIOLockedValueBox.swift b/Sources/ConnectionPoolModule/NIOLockedValueBox.swift new file mode 100644 index 00000000..e5a3e6a2 --- /dev/null +++ b/Sources/ConnectionPoolModule/NIOLockedValueBox.swift @@ -0,0 +1,46 @@ +// Implementation vendored from SwiftNIO: +// https://github.com/apple/swift-nio + +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +/// Provides locked access to `Value`. +/// +/// - note: ``NIOLockedValueBox`` has reference semantics and holds the `Value` +/// alongside a lock behind a reference. +/// +/// This is no different than creating a ``Lock`` and protecting all +/// accesses to a value using the lock. But it's easy to forget to actually +/// acquire/release the lock in the correct place. ``NIOLockedValueBox`` makes +/// that much easier. +@usableFromInline +struct NIOLockedValueBox { + + @usableFromInline + internal let _storage: LockStorage + + /// Initialize the `Value`. + @inlinable + init(_ value: Value) { + self._storage = .create(value: value) + } + + /// Access the `Value`, allowing mutation of it. + @inlinable + func withLockedValue(_ mutate: (inout Value) throws -> T) rethrows -> T { + return try self._storage.withLockedValue(mutate) + } +} + +extension NIOLockedValueBox: Sendable where Value: Sendable {} diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift index 16970599..e735d277 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -342,9 +342,9 @@ extension PoolStateMachine { /// Call ``leaseConnection(at:)`` or ``closeConnection(at:)`` with the supplied index after /// this. If you want to park the connection no further call is required. @inlinable - mutating func releaseConnection(_ connectionID: Connection.ID, streams: UInt16) -> (Int, AvailableConnectionContext) { + mutating func releaseConnection(_ connectionID: Connection.ID, streams: UInt16) -> (Int, AvailableConnectionContext)? { guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { - preconditionFailure("A connection that we don't know was released? Something is very wrong...") + return nil } let connectionInfo = self.connections[index].release(streams: streams) @@ -657,3 +657,6 @@ extension PoolStateMachine { @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) extension PoolStateMachine.ConnectionGroup.BackoffDoneAction: Equatable where TimerCancellationToken: Equatable {} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PoolStateMachine.ConnectionGroup.ClosedAction: Equatable where TimerCancellationToken: Equatable {} diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index aa62d749..4cd78c0e 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -234,7 +234,9 @@ struct PoolStateMachine< @inlinable mutating func releaseConnection(_ connection: Connection, streams: UInt16) -> Action { - let (index, context) = self.connections.releaseConnection(connection.id, streams: streams) + guard let (index, context) = self.connections.releaseConnection(connection.id, streams: streams) else { + return .none() + } return self.handleAvailableConnection(index: index, availableContext: context) } @@ -251,8 +253,13 @@ struct PoolStateMachine< @inlinable mutating func connectionEstablished(_ connection: Connection, maxStreams: UInt16) -> Action { - let (index, context) = self.connections.newConnectionEstablished(connection, maxStreams: maxStreams) - return self.handleAvailableConnection(index: index, availableContext: context) + switch self.poolState { + case .running, .shuttingDown(graceful: true): + let (index, context) = self.connections.newConnectionEstablished(connection, maxStreams: maxStreams) + return self.handleAvailableConnection(index: index, availableContext: context) + case .shuttingDown(graceful: false), .shutDown: + return .init(request: .none, connection: .closeConnection(connection, [])) + } } @inlinable @@ -274,31 +281,43 @@ struct PoolStateMachine< @inlinable mutating func connectionEstablishFailed(_ error: Error, for request: ConnectionRequest) -> Action { - self.failedConsecutiveConnectionAttempts += 1 + switch self.poolState { + case .running, .shuttingDown(graceful: true): + self.failedConsecutiveConnectionAttempts += 1 - let connectionTimer = self.connections.backoffNextConnectionAttempt(request.connectionID) - let backoff = Self.calculateBackoff(failedAttempt: self.failedConsecutiveConnectionAttempts) - let timer = Timer(connectionTimer, duration: backoff) - return .init(request: .none, connection: .scheduleTimers(.init(timer))) + let connectionTimer = self.connections.backoffNextConnectionAttempt(request.connectionID) + let backoff = Self.calculateBackoff(failedAttempt: self.failedConsecutiveConnectionAttempts) + let timer = Timer(connectionTimer, duration: backoff) + return .init(request: .none, connection: .scheduleTimers(.init(timer))) + + case .shuttingDown(graceful: false), .shutDown: + return .none() + } } @inlinable mutating func connectionCreationBackoffDone(_ connectionID: ConnectionID) -> Action { - let soonAvailable = self.connections.soonAvailableConnections - let retry = (soonAvailable - 1) < self.requestQueue.count - - switch self.connections.backoffDone(connectionID, retry: retry) { - case .createConnection(let request, let continuation): - let timers: TinyFastSequence - if let continuation { - timers = .init(element: continuation) - } else { - timers = .init() + switch self.poolState { + case .running, .shuttingDown(graceful: true): + let soonAvailable = self.connections.soonAvailableConnections + let retry = (soonAvailable - 1) < self.requestQueue.count + + switch self.connections.backoffDone(connectionID, retry: retry) { + case .createConnection(let request, let continuation): + let timers: TinyFastSequence + if let continuation { + timers = .init(element: continuation) + } else { + timers = .init() + } + return .init(request: .none, connection: .makeConnection(request, timers)) + + case .cancelTimers(let timers): + return .init(request: .none, connection: .cancelTimers(.init(timers))) } - return .init(request: .none, connection: .makeConnection(request, timers)) - case .cancelTimers(let timers): - return .init(request: .none, connection: .cancelTimers(.init(timers))) + case .shuttingDown(graceful: false), .shutDown: + return .none() } } diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift new file mode 100644 index 00000000..b27fff37 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -0,0 +1,189 @@ +@testable import _ConnectionPoolModule +import XCTest +import NIOEmbedded + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class ConnectionPoolTests: XCTestCase { + + func test1000ConsecutiveRequestsOnSingleConnection() async { + let factory = MockConnectionFactory() + + var config = ConnectionPoolConfiguration() + config.minimumConnectionCount = 1 + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: ContinuousClock() + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + // the same connection is reused 1000 times + + await withTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + let createdConnection = await factory.nextConnectAttempt { _ in + return 1 + } + XCTAssertNotNil(createdConnection) + + do { + for _ in 0..<1000 { + async let connectionFuture = try await pool.leaseConnection() + var leasedConnection: MockConnection? + XCTAssertEqual(factory.pendingConnectionAttemptsCount, 0) + leasedConnection = try await connectionFuture + XCTAssertNotNil(leasedConnection) + XCTAssert(createdConnection === leasedConnection) + + if let leasedConnection { + pool.releaseConnection(leasedConnection) + } + } + } catch { + XCTFail("Unexpected error: \(error)") + } + + taskGroup.cancelAll() + } + } + + func testShutdownPoolWhileConnectionIsBeingCreated() async { + let clock = MockClock() + let factory = MockConnectionFactory() + + var config = ConnectionPoolConfiguration() + config.minimumConnectionCount = 1 + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + await withTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + let (blockCancelStream, blockCancelContinuation) = AsyncStream.makeStream(of: Void.self) + let (blockConnCreationStream, blockConnCreationContinuation) = AsyncStream.makeStream(of: Void.self) + + taskGroup.addTask { + _ = try? await factory.nextConnectAttempt { _ in + blockCancelContinuation.yield() + var iterator = blockConnCreationStream.makeAsyncIterator() + await iterator.next() + throw ConnectionCreationError() + } + } + + var iterator = blockCancelStream.makeAsyncIterator() + await iterator.next() + + taskGroup.cancelAll() + blockConnCreationContinuation.yield() + } + + struct ConnectionCreationError: Error {} + } + + func testShutdownPoolWhileConnectionIsBackingOff() async { + let clock = MockClock() + let factory = MockConnectionFactory() + + var config = ConnectionPoolConfiguration() + config.minimumConnectionCount = 1 + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + await withTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + _ = try? await factory.nextConnectAttempt { _ in + throw ConnectionCreationError() + } + + await clock.timerScheduled() + + taskGroup.cancelAll() + } + + struct ConnectionCreationError: Error {} + } + + func testConnectionHardLimitIsRespected() async { + let factory = MockConnectionFactory() + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 4 + mutableConfig.maximumConnectionHardLimit = 8 + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: ContinuousClock() + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + // the same connection is reused 1000 times + + await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + taskGroup.addTask { + var usedConnectionIDs = Set() + for _ in 0.. Self { + .init(self.base + duration) + } + + func duration(to other: Self) -> Self.Duration { + self.base - other.base + } + + private var base: Swift.Duration + + init(_ base: Duration) { + self.base = base + } + + static func < (lhs: Self, rhs: Self) -> Bool { + lhs.base < rhs.base + } + + static func == (lhs: Self, rhs: Self) -> Bool { + lhs.base == rhs.base + } + } + + private struct State: Sendable { + var now: Instant + + var sleepersHeap: Array + + var waitersHeap: Array + + init() { + self.now = .init(.seconds(0)) + self.sleepersHeap = Array() + self.waitersHeap = Array() + } + } + + private struct Waiter { + var expectedSleepers: Int + + var continuation: CheckedContinuation + } + + private struct Sleeper { + var id: Int + + var deadline: Instant + + var continuation: CheckedContinuation + } + + typealias Duration = Swift.Duration + + var minimumResolution: Duration { .nanoseconds(1) } + + var now: Instant { self.stateBox.withLockedValue { $0.now } } + + private let stateBox = NIOLockedValueBox(State()) + private let waiterIDGenerator = ManagedAtomic(0) + + func sleep(until deadline: Instant, tolerance: Duration?) async throws { + let waiterID = self.waiterIDGenerator.loadThenWrappingIncrement(ordering: .relaxed) + + return try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + enum SleepAction { + case none + case resume + case cancel + } + + let action = self.stateBox.withLockedValue { state -> (SleepAction, ArraySlice) in + state.waitersHeap = state.waitersHeap.map { waiter in + var waiter = waiter; waiter.expectedSleepers -= 1; return waiter + } + let slice: ArraySlice + let lastRemainingIndex = state.waitersHeap.firstIndex(where: { $0.expectedSleepers > 0 }) + if let lastRemainingIndex { + slice = state.waitersHeap[0..= deadline { + return (.resume, slice) + } + + let newWaiter = Sleeper(id: waiterID, deadline: deadline, continuation: continuation) + + if let index = state.sleepersHeap.lastIndex(where: { $0.deadline < deadline }) { + state.sleepersHeap.insert(newWaiter, at: index + 1) + } else { + state.sleepersHeap.append(newWaiter) + } + + return (.none, slice) + } + + switch action.0 { + case .cancel: + continuation.resume(throwing: CancellationError()) + case .resume: + continuation.resume() + case .none: + break + } + + for waiter in action.1 { + waiter.continuation.resume() + } + } + } onCancel: { + let continuation = self.stateBox.withLockedValue { state -> CheckedContinuation? in + if let index = state.sleepersHeap.firstIndex(where: { $0.id == waiterID }) { + return state.sleepersHeap.remove(at: index).continuation + } + return nil + } + continuation?.resume(throwing: CancellationError()) + } + } + + func timerScheduled(n: Int = 1) async { + precondition(n >= 1, "At least one new sleep must be awaited") + await withCheckedContinuation { (continuation: CheckedContinuation<(), Never>) in + let result = self.stateBox.withLockedValue { state -> Bool in + let n = n - state.sleepersHeap.count + + if n <= 0 { + return true + } + + let waiter = Waiter(expectedSleepers: n, continuation: continuation) + + if let index = state.waitersHeap.firstIndex(where: { $0.expectedSleepers > n }) { + state.waitersHeap.insert(waiter, at: index) + } else { + state.waitersHeap.append(waiter) + } + return false + } + + if result { + continuation.resume() + } + } + } + + func advance(to deadline: Instant) { + let waiters = self.stateBox.withLockedValue { state -> ArraySlice in + precondition(deadline > state.now, "Time can only move forward") + state.now = deadline + + if let newFirstIndex = state.sleepersHeap.firstIndex(where: { $0.deadline > deadline }) { + defer { state.sleepersHeap.removeFirst(newFirstIndex) } + return state.sleepersHeap[0.. where Clock.Duration == Duration { + typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator + typealias Request = ConnectionRequest + typealias KeepAliveBehavior = MockPingPongBehavior + typealias MetricsDelegate = NoOpConnectionPoolMetrics + typealias ConnectionID = Int + typealias Connection = MockConnection + + let stateBox = NIOLockedValueBox(State()) + + struct State { + var attempts = Deque<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)>() + + var waiter = Deque), Never>>() + } + + var pendingConnectionAttemptsCount: Int { + self.stateBox.withLockedValue { $0.attempts.count } + } + + func makeConnection( + id: Int, + for pool: ConnectionPool, Int, MockPingPongBehavior, NoOpConnectionPoolMetrics, Clock> + ) async throws -> ConnectionAndMetadata { + // we currently don't support cancellation when creating a connection + let result = try await withCheckedThrowingContinuation { (checkedContinuation: CheckedContinuation<(MockConnection, UInt16), any Error>) in + let waiter = self.stateBox.withLockedValue { state -> (CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>)? in + if let waiter = state.waiter.popFirst() { + return waiter + } else { + state.attempts.append((id, checkedContinuation)) + return nil + } + } + + if let waiter { + waiter.resume(returning: (id, checkedContinuation)) + } + } + + return .init(connection: result.0, maximalStreamsOnConnection: result.1) + } + + @discardableResult + func nextConnectAttempt(_ closure: (ConnectionID) async throws -> UInt16) async rethrows -> Connection { + let (connectionID, continuation) = await withCheckedContinuation { (continuation: CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>) in + let attempt = self.stateBox.withLockedValue { state -> (ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)? in + if let attempt = state.attempts.popFirst() { + return attempt + } else { + state.waiter.append(continuation) + return nil + } + } + + if let attempt { + continuation.resume(returning: attempt) + } + } + + do { + let streamCount = try await closure(connectionID) + let connection = MockConnection(id: connectionID) + continuation.resume(returning: (connection, streamCount)) + return connection + } catch { + continuation.resume(throwing: error) + throw error + } + } +} diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift new file mode 100644 index 00000000..2ee9b7a0 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift @@ -0,0 +1,14 @@ +import _ConnectionPoolModule + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +struct MockPingPongBehavior: ConnectionKeepAliveBehavior { + let keepAliveFrequency: Duration? + + init(keepAliveFrequency: Duration?) { + self.keepAliveFrequency = keepAliveFrequency + } + + func runKeepAlive(for connection: MockConnection) async throws { + preconditionFailure() + } +} diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift index bf385918..99b73fd0 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift @@ -88,7 +88,9 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { XCTAssert(newConnection === leaseResult.connection) XCTAssertEqual(connections.stats, .init(leased: 1, leasedStreams: 1)) - let (index, releasedContext) = connections.releaseConnection(leaseResult.connection.id, streams: 1) + guard let (index, releasedContext) = connections.releaseConnection(leaseResult.connection.id, streams: 1) else { + return XCTFail("Expected that this connection is still active") + } XCTAssertEqual(releasedContext.info, .idle(availableStreams: 1, newIdle: true)) XCTAssertEqual(releasedContext.use, .demand) XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift index 0f3af728..a19d2326 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -224,4 +224,46 @@ final class PoolStateMachineTests: XCTestCase { // connection 1 is dropped XCTAssertEqual(stateMachine.connectionClosed(connection1), .init(request: .none, connection: .cancelTimers([connection2IdleTimerCancellationToken]))) } + + func testReleaseLoosesRaceAgainstClosed() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 0 + configuration.maximumConnectionSoftLimit = 2 + configuration.maximumConnectionHardLimit = 2 + configuration.keepAliveDuration = nil + configuration.idleTimeoutDuration = .seconds(3) + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + // don't refill pool + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 0) + + // request connection while none exists + let request1 = MockRequest() + let leaseRequest1 = stateMachine.leaseConnection(request1) + XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) + XCTAssertEqual(leaseRequest1.request, .none) + + // make connection 1 and lease immediately + let connection1 = MockConnection(id: 0) + let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) + XCTAssertEqual(createdAction1.request, .leaseConnection(.init(element: request1), connection1)) + XCTAssertEqual(createdAction1.connection, .none) + + // connection got closed + let closedAction = stateMachine.connectionClosed(connection1) + XCTAssertEqual(closedAction.connection, .none) + XCTAssertEqual(closedAction.request, .none) + + // release connection 1 should be leased again immediately + let releaseRequest1 = stateMachine.releaseConnection(connection1, streams: 1) + XCTAssertEqual(releaseRequest1.request, .none) + XCTAssertEqual(releaseRequest1.connection, .none) + } + } diff --git a/Tests/ConnectionPoolModuleTests/TinyFastSequence.swift b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift similarity index 97% rename from Tests/ConnectionPoolModuleTests/TinyFastSequence.swift rename to Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift index b3f8179d..1a2836b9 100644 --- a/Tests/ConnectionPoolModuleTests/TinyFastSequence.swift +++ b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift @@ -1,7 +1,7 @@ @testable import _ConnectionPoolModule import XCTest -final class OneElementFastSequenceTests: XCTestCase { +final class TinyFastSequenceTests: XCTestCase { func testCountIsEmptyAndIterator() async { var sequence = TinyFastSequence() XCTAssertEqual(sequence.count, 0) From add68a0aed8d794a5608318452495621d038b255 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sat, 28 Oct 2023 15:23:47 +0200 Subject: [PATCH 205/292] Ensure pool runs until all connections are closed (#429) - Ensure pool runs until all connections are closed - Fix an ordering issue in `RequestQueue` - Remove unused `closeConnection` in NewPoolActions --- .../ConnectionPoolModule/ConnectionPool.swift | 15 ++++--- .../PoolStateMachine+RequestQueue.swift | 2 +- .../PoolStateMachine.swift | 24 ++++++---- .../ConnectionPoolTests.swift | 44 +++++++++++++++++-- .../Mocks/MockConnection.swift | 17 +++++++ 5 files changed, 82 insertions(+), 20 deletions(-) diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index 5571e617..e9c9c4c9 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -306,7 +306,6 @@ public final class ConnectionPool< @usableFromInline enum NewPoolActions: Sendable { case makeConnection(StateMachine.ConnectionRequest) - case closeConnection(Connection) case runKeepAlive(Connection) case scheduleTimer(StateMachine.Timer) @@ -342,9 +341,6 @@ public final class ConnectionPool< case .runKeepAlive(let connection): self.runKeepAlive(connection, in: &taskGroup) - case .closeConnection(let connection): - self.closeConnection(connection) - case .scheduleTimer(let timer): self.runTimer(timer, in: &taskGroup) } @@ -427,8 +423,15 @@ public final class ConnectionPool< do { let bundle = try await self.factory(request.connectionID, self) self.connectionEstablished(bundle) - bundle.connection.onClose { - self.connectionDidClose(bundle.connection, error: $0) + + // after the connection has been established, we keep the task open. This ensures + // that the pools run method can not be exited before all connections have been + // closed. + await withCheckedContinuation { (continuation: CheckedContinuation) in + bundle.connection.onClose { + self.connectionDidClose(bundle.connection, error: $0) + continuation.resume() + } } } catch { self.connectionEstablishFailed(error, for: request) diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift b/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift index f1d6f4e4..99ec4896 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+RequestQueue.swift @@ -44,7 +44,7 @@ extension PoolStateMachine { var result = TinyFastSequence() result.reserveCapacity(Int(max)) var popped = 0 - while let requestID = self.queue.popFirst(), popped < max { + while popped < max, let requestID = self.queue.popFirst() { if let requestIndex = self.requests.index(forKey: requestID) { popped += 1 result.append(self.requests.remove(at: requestIndex).value) diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index 4cd78c0e..4b3680a1 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -355,18 +355,24 @@ struct PoolStateMachine< @inlinable mutating func connectionClosed(_ connection: Connection) -> Action { - self.cacheNoMoreConnectionsAllowed = false + switch self.poolState { + case .running, .shuttingDown(graceful: true): + self.cacheNoMoreConnectionsAllowed = false - let closedConnectionAction = self.connections.connectionClosed(connection.id) + let closedConnectionAction = self.connections.connectionClosed(connection.id) - let connectionAction: ConnectionAction - if let newRequest = closedConnectionAction.newConnectionRequest { - connectionAction = .makeConnection(newRequest, closedConnectionAction.timersToCancel) - } else { - connectionAction = .cancelTimers(closedConnectionAction.timersToCancel) - } + let connectionAction: ConnectionAction + if let newRequest = closedConnectionAction.newConnectionRequest { + connectionAction = .makeConnection(newRequest, closedConnectionAction.timersToCancel) + } else { + connectionAction = .cancelTimers(closedConnectionAction.timersToCancel) + } + + return .init(request: .none, connection: connectionAction) - return .init(request: .none, connection: connectionAction) + case .shuttingDown(graceful: false), .shutDown: + return .none() + } } struct CleanupAction { diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index b27fff37..5be12a1c 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -1,4 +1,5 @@ @testable import _ConnectionPoolModule +import Atomics import XCTest import NIOEmbedded @@ -52,7 +53,14 @@ final class ConnectionPoolTests: XCTestCase { } taskGroup.cancelAll() + + XCTAssertEqual(factory.pendingConnectionAttemptsCount, 0) + for connection in factory.runningConnections { + connection.closeIfClosing() + } } + + XCTAssertEqual(factory.runningConnections.count, 0) } func testShutdownPoolWhileConnectionIsBeingCreated() async { @@ -155,11 +163,16 @@ final class ConnectionPoolTests: XCTestCase { try await factory.makeConnection(id: $0, for: $1) } + let hasFinished = ManagedAtomic(false) + let createdConnections = ManagedAtomic(0) + let iterations = 10_000 + // the same connection is reused 1000 times - await withThrowingTaskGroup(of: Void.self) { taskGroup in + await withTaskGroup(of: Void.self) { taskGroup in taskGroup.addTask { await pool.run() + XCTAssertFalse(hasFinished.compareExchange(expected: false, desired: true, ordering: .relaxed).original) } taskGroup.addTask { @@ -167,22 +180,45 @@ final class ConnectionPoolTests: XCTestCase { for _ in 0.. where Clock.Duratio var attempts = Deque<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)>() var waiter = Deque), Never>>() + + var runningConnections = [ConnectionID: Connection]() } var pendingConnectionAttemptsCount: Int { self.stateBox.withLockedValue { $0.attempts.count } } + var runningConnections: [Connection] { + self.stateBox.withLockedValue { Array($0.runningConnections.values) } + } + func makeConnection( id: Int, for pool: ConnectionPool, Int, MockPingPongBehavior, NoOpConnectionPoolMetrics, Clock> @@ -137,6 +143,17 @@ final class MockConnectionFactory where Clock.Duratio do { let streamCount = try await closure(connectionID) let connection = MockConnection(id: connectionID) + + connection.onClose { _ in + self.stateBox.withLockedValue { state in + _ = state.runningConnections.removeValue(forKey: connectionID) + } + } + + self.stateBox.withLockedValue { state in + _ = state.runningConnections[connectionID] = connection + } + continuation.resume(returning: (connection, streamCount)) return connection } catch { From 2905779f4a0ccf7fa59e1e8e951b7a1c31e689e3 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 30 Oct 2023 11:01:48 +0100 Subject: [PATCH 206/292] Land PostgresClient that is backed by a ConnectionPool as SPI (#430) --- .../PoolStateMachine+ConnectionGroup.swift | 9 +- .../PoolStateMachine.swift | 4 +- .../Connection/PostgresConnection.swift | 4 +- .../ConnectionStateMachine.swift | 2 +- Sources/PostgresNIO/New/PSQLError.swift | 16 +- .../PostgresNIO/Pool/ConnectionFactory.swift | 206 ++++++++++ Sources/PostgresNIO/Pool/PostgresClient.swift | 378 ++++++++++++++++++ .../Pool/PostgresClientMetrics.swift | 85 ++++ Sources/PostgresNIO/Postgres+PSQLCompat.swift | 2 + ...oolStateMachine+ConnectionGroupTests.swift | 6 +- .../PostgresClientTests.swift | 66 +++ 11 files changed, 764 insertions(+), 14 deletions(-) create mode 100644 Sources/PostgresNIO/Pool/ConnectionFactory.swift create mode 100644 Sources/PostgresNIO/Pool/PostgresClient.swift create mode 100644 Sources/PostgresNIO/Pool/PostgresClientMetrics.swift create mode 100644 Tests/IntegrationTests/PostgresClientTests.swift diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift index e735d277..b53f8d68 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -308,7 +308,7 @@ extension PoolStateMachine { } @inlinable - mutating func parkConnection(at index: Int) -> Max2Sequence { + mutating func parkConnection(at index: Int, hasBecomeIdle newIdle: Bool) -> Max2Sequence { let scheduleIdleTimeoutTimer: Bool switch index { case 0.. + + struct SSLContextCache: Sendable { + enum State { + case none + case producing(TLSConfiguration, [CheckedContinuation]) + case cached(TLSConfiguration, NIOSSLContext) + case failed(TLSConfiguration, any Error) + } + + var state: State = .none + } + + let sslContextBox = NIOLockedValueBox(SSLContextCache()) + + let eventLoopGroup: any EventLoopGroup + + let logger: Logger + + init(config: PostgresClient.Configuration, eventLoopGroup: any EventLoopGroup, logger: Logger) { + self.eventLoopGroup = eventLoopGroup + self.configBox = NIOLockedValueBox(ConfigCache(config: config)) + self.logger = logger + } + + func makeConnection(_ connectionID: PostgresConnection.ID, pool: PostgresClient.Pool) async throws -> PostgresConnection { + let config = try await self.makeConnectionConfig() + + var connectionLogger = self.logger + connectionLogger[postgresMetadataKey: .connectionID] = "\(connectionID)" + + return try await PostgresConnection.connect( + on: self.eventLoopGroup.any(), + configuration: config, + id: connectionID, + logger: connectionLogger + ).get() + } + + func makeConnectionConfig() async throws -> PostgresConnection.Configuration { + let config = self.configBox.withLockedValue { $0.config } + + let tls: PostgresConnection.Configuration.TLS + switch config.tls.base { + case .prefer(let tlsConfiguration): + let sslContext = try await self.getSSLContext(for: tlsConfiguration) + tls = .prefer(sslContext) + + case .require(let tlsConfiguration): + let sslContext = try await self.getSSLContext(for: tlsConfiguration) + tls = .require(sslContext) + case .disable: + tls = .disable + } + + var connectionConfig: PostgresConnection.Configuration + switch config.endpointInfo { + case .bindUnixDomainSocket(let path): + connectionConfig = PostgresConnection.Configuration( + unixSocketPath: path, + username: config.username, + password: config.password, + database: config.database + ) + + case .connectTCP(let host, let port): + connectionConfig = PostgresConnection.Configuration( + host: host, + port: port, + username: config.username, + password: config.password, + database: config.database, + tls: tls + ) + } + + connectionConfig.options.connectTimeout = TimeAmount(config.options.connectTimeout) + connectionConfig.options.tlsServerName = config.options.tlsServerName + connectionConfig.options.requireBackendKeyData = config.options.requireBackendKeyData + + return connectionConfig + } + + private func getSSLContext(for tlsConfiguration: TLSConfiguration) async throws -> NIOSSLContext { + enum Action { + case produce + case succeed(NIOSSLContext) + case fail(any Error) + case wait + } + + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let action = self.sslContextBox.withLockedValue { cache -> Action in + switch cache.state { + case .none: + cache.state = .producing(tlsConfiguration, [continuation]) + return .produce + + case .cached(let cachedTLSConfiguration, let context): + if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { + return .succeed(context) + } else { + cache.state = .producing(tlsConfiguration, [continuation]) + return .produce + } + + case .failed(let cachedTLSConfiguration, let error): + if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { + return .fail(error) + } else { + cache.state = .producing(tlsConfiguration, [continuation]) + return .produce + } + + case .producing(let cachedTLSConfiguration, var continuations): + continuations.append(continuation) + if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { + cache.state = .producing(cachedTLSConfiguration, continuations) + return .wait + } else { + cache.state = .producing(tlsConfiguration, continuations) + return .produce + } + } + } + + switch action { + case .wait: + break + + case .produce: + // TBD: we might want to consider moving this off the concurrent executor + self.reportProduceSSLContextResult( + Result(catching: {try NIOSSLContext(configuration: tlsConfiguration)}), + for: tlsConfiguration + ) + + case .succeed(let context): + continuation.resume(returning: context) + + case .fail(let error): + continuation.resume(throwing: error) + } + } + } + + private func reportProduceSSLContextResult(_ result: Result, for tlsConfiguration: TLSConfiguration) { + enum Action { + case fail(any Error, [CheckedContinuation]) + case succeed(NIOSSLContext, [CheckedContinuation]) + case none + } + + let action = self.sslContextBox.withLockedValue { cache -> Action in + switch cache.state { + case .none: + preconditionFailure("Invalid state: \(cache.state)") + + case .cached, .failed: + return .none + + case .producing(let cachedTLSConfiguration, let continuations): + if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { + switch result { + case .success(let context): + cache.state = .cached(cachedTLSConfiguration, context) + return .succeed(context, continuations) + + case .failure(let failure): + cache.state = .failed(cachedTLSConfiguration, failure) + return .fail(failure, continuations) + } + } else { + return .none + } + } + } + + switch action { + case .none: + break + + case .succeed(let context, let continuations): + for continuation in continuations { + continuation.resume(returning: context) + } + + case .fail(let error, let continuations): + for continuation in continuations { + continuation.resume(throwing: error) + } + } + } +} diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift new file mode 100644 index 00000000..fc5a5b00 --- /dev/null +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -0,0 +1,378 @@ +import NIOCore +import NIOSSL +import Atomics +import Logging +import _ConnectionPoolModule + +/// A Postgres client that is backed by an underlying connection pool. Use ``Configuration`` to change the client's +/// behavior. +/// +/// > Important: +/// The client can only lease connections if the user is running the client's ``run()`` method in a long running task: +/// +/// ```swift +/// let client = PostgresClient(configuration: configuration, logger: logger) +/// await withTaskGroup(of: Void.self) { +/// taskGroup.addTask { +/// client.run() // !important +/// } +/// +/// taskGroup.addTask { +/// client.withConnection { connection in +/// do { +/// let rows = try await connection.query("SELECT userID, name, age FROM users;") +/// for try await (userID, name, age) in rows.decode((UUID, String, Int).self) { +/// // do something with the values +/// } +/// } catch { +/// // handle errors +/// } +/// } +/// } +/// } +/// ``` +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +@_spi(ConnectionPool) +public final class PostgresClient: Sendable { + public struct Configuration: Sendable { + public struct TLS: Sendable { + enum Base { + case disable + case prefer(NIOSSL.TLSConfiguration) + case require(NIOSSL.TLSConfiguration) + } + + var base: Base + + private init(_ base: Base) { + self.base = base + } + + /// Do not try to create a TLS connection to the server. + public static var disable: Self = Self.init(.disable) + + /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. + /// If the server does not support TLS, create an insecure connection. + public static func prefer(_ sslContext: NIOSSL.TLSConfiguration) -> Self { + self.init(.prefer(sslContext)) + } + + /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. + /// If the server does not support TLS, fail the connection creation. + public static func require(_ sslContext: NIOSSL.TLSConfiguration) -> Self { + self.init(.require(sslContext)) + } + } + + // MARK: Client options + + /// Describes general client behavior options. Those settings are considered advanced options. + public struct Options: Sendable { + /// A keep-alive behavior for Postgres connections. The ``frequency`` defines after which time an idle + /// connection shall run a keep-alive ``query``. + public struct KeepAliveBehavior: Sendable { + /// The amount of time that shall pass before an idle connection runs a keep-alive ``query``. + public var frequency: Duration + + /// The ``query`` that is run on an idle connection after it has been idle for ``frequency``. + public var query: PostgresQuery + + /// Create a new `KeepAliveBehavior`. + /// - Parameters: + /// - frequency: The amount of time that shall pass before an idle connection runs a keep-alive `query`. + /// Defaults to `30` seconds. + /// - query: The `query` that is run on an idle connection after it has been idle for `frequency`. + /// Defaults to `SELECT 1;`. + public init(frequency: Duration = .seconds(30), query: PostgresQuery = "SELECT 1;") { + self.frequency = frequency + self.query = query + } + } + + /// A timeout for creating a TCP/Unix domain socket connection. Defaults to `10` seconds. + public var connectTimeout: Duration = .seconds(10) + + /// The server name to use for certificate validation and SNI (Server Name Indication) when TLS is enabled. + /// Defaults to none (but see below). + /// + /// > When set to `nil`: + /// If the connection is made to a server over TCP using + /// ``PostgresConnection/Configuration/init(host:port:username:password:database:tls:)``, the given `host` + /// is used, unless it was an IP address string. If it _was_ an IP, or the connection is made by any other + /// method, SNI is disabled. + public var tlsServerName: String? = nil + + /// Whether the connection is required to provide backend key data (internal Postgres stuff). + /// + /// This property is provided for compatibility with Amazon RDS Proxy, which requires it to be `false`. + /// If you are not using Amazon RDS Proxy, you should leave this set to `true` (the default). + public var requireBackendKeyData: Bool = true + + /// The minimum number of connections that the client shall keep open at any time, even if there is no + /// demand. Default to `0`. + /// + /// If the open connection count becomes less than ``minimumConnections`` new connections + /// are created immidiatly. Must be greater or equal to zero and less than ``maximumConnections``. + /// + /// Idle connections are kept alive using the ``keepAliveBehavior``. + public var minimumConnections: Int = 0 + + /// The maximum number of connections that the client may open to the server at any time. Must be greater + /// than ``minimumConnections``. Defaults to `20` connections. + /// + /// Connections, that are created in response to demand are kept alive for the ``connectionIdleTimeout`` + /// before they are dropped. + public var maximumConnections: Int = 20 + + /// The maximum amount time that a connection that is not part of the ``minimumConnections`` is kept + /// open without being leased. Defaults to `60` seconds. + public var connectionIdleTimeout: Duration = .seconds(60) + + /// The ``KeepAliveBehavior-swift.struct`` to ensure that the underlying tcp-connection is still active + /// for idle connections. `Nil` means that the client shall not run keep alive queries to the server. Defaults to a + /// keep alive query of `SELECT 1;` every `30` seconds. + public var keepAliveBehavior: KeepAliveBehavior? = KeepAliveBehavior() + + /// Create an options structure with default values. + /// + /// Most users should not need to adjust the defaults. + public init() {} + } + + // MARK: - Accessors + + /// The hostname to connect to for TCP configurations. + /// + /// Always `nil` for other configurations. + public var host: String? { + if case let .connectTCP(host, _) = self.endpointInfo { return host } + else { return nil } + } + + /// The port to connect to for TCP configurations. + /// + /// Always `nil` for other configurations. + public var port: Int? { + if case let .connectTCP(_, port) = self.endpointInfo { return port } + else { return nil } + } + + /// The socket path to connect to for Unix domain socket connections. + /// + /// Always `nil` for other configurations. + public var unixSocketPath: String? { + if case let .bindUnixDomainSocket(path) = self.endpointInfo { return path } + else { return nil } + } + + /// The TLS mode to use for the connection. Valid for all configurations. + /// + /// See ``TLS-swift.struct``. + public var tls: TLS = .prefer(.makeClientConfiguration()) + + /// Options for handling the communication channel. Most users don't need to change these. + /// + /// See ``Options-swift.struct``. + public var options: Options = .init() + + /// The username to connect with. + public var username: String + + /// The password, if any, for the user specified by ``username``. + /// + /// - Warning: `nil` means "no password provided", whereas `""` (the empty string) is a password of zero + /// length; these are not the same thing. + public var password: String? + + /// The name of the database to open. + /// + /// - Note: If set to `nil` or an empty string, the provided ``username`` is used. + public var database: String? + + // MARK: - Initializers + + /// Create a configuration for connecting to a server with a hostname and optional port. + /// + /// This specifies a TCP connection. If you're unsure which kind of connection you want, you almost + /// definitely want this one. + /// + /// - Parameters: + /// - host: The hostname to connect to. + /// - port: The TCP port to connect to (defaults to 5432). + /// - tls: The TLS mode to use. + public init(host: String, port: Int = 5432, username: String, password: String?, database: String?, tls: TLS) { + self.init(endpointInfo: .connectTCP(host: host, port: port), tls: tls, username: username, password: password, database: database) + } + + /// Create a configuration for connecting to a server through a UNIX domain socket. + /// + /// - Parameters: + /// - path: The filesystem path of the socket to connect to. + /// - tls: The TLS mode to use. Defaults to ``TLS-swift.struct/disable``. + public init(unixSocketPath: String, username: String, password: String?, database: String?) { + self.init(endpointInfo: .bindUnixDomainSocket(path: unixSocketPath), tls: .disable, username: username, password: password, database: database) + } + + // MARK: - Implementation details + + enum EndpointInfo { + case bindUnixDomainSocket(path: String) + case connectTCP(host: String, port: Int) + } + + var endpointInfo: EndpointInfo + + init(endpointInfo: EndpointInfo, tls: TLS, username: String, password: String?, database: String?) { + self.endpointInfo = endpointInfo + self.tls = tls + self.username = username + self.password = password + self.database = database + } + } + + typealias Pool = ConnectionPool< + PostgresConnection, + PostgresConnection.ID, + ConnectionIDGenerator, + ConnectionRequest, + ConnectionRequest.ID, + PostgresKeepAliveBehavor, + PostgresClientMetrics, + ContinuousClock + > + + let pool: Pool + let factory: ConnectionFactory + let runningAtomic = ManagedAtomic(false) + let backgroundLogger: Logger + + /// Creates a new ``PostgresClient``. Don't forget to run ``run()`` the client in a long running task. + /// - Parameters: + /// - configuration: The client's configuration. See ``Configuration`` for details. + /// - eventLoopGroup: The underlying NIO `EventLoopGroup`. Defaults to ``defaultEventLoopGroup``. + /// - backgroundLogger: A `swift-log` `Logger` to log background messages to. A copy of this logger is also + /// forwarded to the created connections as a background logger. + public init( + configuration: Configuration, + eventLoopGroup: any EventLoopGroup = PostgresClient.defaultEventLoopGroup, + backgroundLogger: Logger + ) { + let factory = ConnectionFactory(config: configuration, eventLoopGroup: eventLoopGroup, logger: backgroundLogger) + self.factory = factory + self.backgroundLogger = backgroundLogger + + self.pool = ConnectionPool( + configuration: .init(configuration), + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: .init(configuration.options.keepAliveBehavior, logger: backgroundLogger), + observabilityDelegate: .init(logger: backgroundLogger), + clock: ContinuousClock() + ) { (connectionID, pool) in + let connection = try await factory.makeConnection(connectionID, pool: pool) + + return ConnectionAndMetadata(connection: connection, maximalStreamsOnConnection: 1) + } + } + + + /// Lease a connection for the provided `closure`'s lifetime. + /// + /// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture + /// the provided `PostgresConnection`. + /// - Returns: The closure's return value. + public func withConnection(_ closure: (PostgresConnection) async throws -> Result) async throws -> Result { + let connection = try await self.leaseConnection() + + defer { self.pool.releaseConnection(connection) } + + return try await closure(connection) + } + + /// The client's run method. Users must call this function in order to start the client's background task processing + /// like creating and destroying connections and running timers. + /// + /// Calls to ``withConnection(_:)`` will emit a `logger` warning, if ``run()`` hasn't been called previously. + public func run() async { + let atomicOp = self.runningAtomic.compareExchange(expected: false, desired: true, ordering: .relaxed) + precondition(!atomicOp.original, "PostgresClient.run() should just be called once!") + await self.pool.run() + } + + // MARK: - Private Methods - + + private func leaseConnection() async throws -> PostgresConnection { + if !self.runningAtomic.load(ordering: .relaxed) { + self.backgroundLogger.warning("Trying to lease connection from `PostgresClient`, but `PostgresClient.run()` hasn't been called yet.") + } + return try await self.pool.leaseConnection() + } + + /// Returns the default `EventLoopGroup` singleton, automatically selecting the best for the platform. + /// + /// This will select the concrete `EventLoopGroup` depending which platform this is running on. + public static var defaultEventLoopGroup: EventLoopGroup { + PostgresConnection.defaultEventLoopGroup + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +struct PostgresKeepAliveBehavor: ConnectionKeepAliveBehavior { + let behavior: PostgresClient.Configuration.Options.KeepAliveBehavior? + let logger: Logger + + init(_ behavior: PostgresClient.Configuration.Options.KeepAliveBehavior?, logger: Logger) { + self.behavior = behavior + self.logger = logger + } + + var keepAliveFrequency: Duration? { + self.behavior?.frequency + } + + func runKeepAlive(for connection: PostgresConnection) async throws { + try await connection.query(self.behavior!.query, logger: self.logger).map { _ in }.get() + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension ConnectionPoolConfiguration { + init(_ config: PostgresClient.Configuration) { + self = ConnectionPoolConfiguration() + self.minimumConnectionCount = config.options.minimumConnections + self.maximumConnectionSoftLimit = config.options.maximumConnections + self.maximumConnectionHardLimit = config.options.maximumConnections + self.idleTimeout = config.options.connectionIdleTimeout + } +} + +@_spi(ConnectionPool) +extension PostgresConnection: PooledConnection { + public func close() { + self.channel.close(mode: .all, promise: nil) + } + + public func onClose(_ closure: @escaping ((any Error)?) -> ()) { + self.closeFuture.whenComplete { _ in closure(nil) } + } +} + +extension ConnectionPoolError { + func mapToPSQLError(lastConnectError: Error?) -> Error { + var psqlError: PSQLError + switch self { + case .poolShutdown: + psqlError = PSQLError.poolClosed + psqlError.underlying = self + + case .requestCancelled: + psqlError = PSQLError.queryCancelled + psqlError.underlying = self + + default: + return self + } + return psqlError + } +} diff --git a/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift b/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift new file mode 100644 index 00000000..aa8215db --- /dev/null +++ b/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift @@ -0,0 +1,85 @@ +import _ConnectionPoolModule +import Logging + +final class PostgresClientMetrics: ConnectionPoolObservabilityDelegate { + typealias ConnectionID = PostgresConnection.ID + + let logger: Logger + + init(logger: Logger) { + self.logger = logger + } + + func startedConnecting(id: ConnectionID) { + self.logger.debug("Creating new connection", metadata: [ + .connectionID: "\(id)", + ]) + } + + /// A connection attempt failed with the given error. After some period of + /// time ``startedConnecting(id:)`` may be called again. + func connectFailed(id: ConnectionID, error: Error) { + self.logger.debug("Connection creation failed", metadata: [ + .connectionID: "\(id)", + .error: "\(String(reflecting: error))" + ]) + } + + func connectSucceeded(id: ConnectionID) { + self.logger.debug("Connection established", metadata: [ + .connectionID: "\(id)" + ]) + } + + /// The utlization of the connection changed; a stream may have been used, returned or the + /// maximum number of concurrent streams available on the connection changed. + func connectionLeased(id: ConnectionID) { + self.logger.debug("Connection leased", metadata: [ + .connectionID: "\(id)" + ]) + } + + func connectionReleased(id: ConnectionID) { + self.logger.debug("Connection released", metadata: [ + .connectionID: "\(id)" + ]) + } + + func keepAliveTriggered(id: ConnectionID) { + self.logger.debug("run ping pong", metadata: [ + .connectionID: "\(id)", + ]) + } + + func keepAliveSucceeded(id: ConnectionID) {} + + func keepAliveFailed(id: PostgresConnection.ID, error: Error) {} + + /// The remote peer is quiescing the connection: no new streams will be created on it. The + /// connection will eventually be closed and removed from the pool. + func connectionClosing(id: ConnectionID) { + self.logger.debug("Close connection", metadata: [ + .connectionID: "\(id)" + ]) + } + + /// The connection was closed. The connection may be established again in the future (notified + /// via ``startedConnecting(id:)``). + func connectionClosed(id: ConnectionID, error: Error?) { + self.logger.debug("Connection closed", metadata: [ + .connectionID: "\(id)" + ]) + } + + func requestQueueDepthChanged(_ newDepth: Int) { + + } + + func connectSucceeded(id: PostgresConnection.ID, streamCapacity: UInt16) { + + } + + func connectionUtilizationChanged(id: PostgresConnection.ID, streamsUsed: UInt16, streamCapacity: UInt16) { + + } +} diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index c4f30624..7d464c2b 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -46,6 +46,8 @@ extension PSQLError { return self.underlying ?? self case .uncleanShutdown: return PostgresError.protocol("Unexpected connection close") + case .poolClosed: + return self } } } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift index 99b73fd0..ac0f96f4 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift @@ -95,7 +95,7 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { XCTAssertEqual(releasedContext.use, .demand) XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) - let parkTimers = connections.parkConnection(at: index) + let parkTimers = connections.parkConnection(at: index, hasBecomeIdle: true) XCTAssertEqual(parkTimers, [ .init(timerID: 0, connectionID: newConnection.id, usecase: .keepAlive), .init(timerID: 1, connectionID: newConnection.id, usecase: .idleTimeout), @@ -199,7 +199,7 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { let thirdConnKeepTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: thirdRequest.connectionID, usecase: .keepAlive) let thirdConnIdleTimer = TestPoolStateMachine.ConnectionTimer(timerID: 1, connectionID: thirdRequest.connectionID, usecase: .idleTimeout) let thirdConnIdleTimerCancellationToken = MockTimerCancellationToken(thirdConnIdleTimer) - XCTAssertEqual(connections.parkConnection(at: thirdConnectionIndex), [thirdConnKeepTimer, thirdConnIdleTimer]) + XCTAssertEqual(connections.parkConnection(at: thirdConnectionIndex, hasBecomeIdle: true), [thirdConnKeepTimer, thirdConnIdleTimer]) XCTAssertNil(connections.timerScheduled(thirdConnKeepTimer, cancelContinuation: .init(thirdConnKeepTimer))) XCTAssertNil(connections.timerScheduled(thirdConnIdleTimer, cancelContinuation: thirdConnIdleTimerCancellationToken)) @@ -277,7 +277,7 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { XCTAssertEqual(establishedConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) XCTAssertEqual(establishedConnectionContext.use, .persisted) XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) - let timers = connections.parkConnection(at: connectionIndex) + let timers = connections.parkConnection(at: connectionIndex, hasBecomeIdle: true) let keepAliveTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: firstRequest.connectionID, usecase: .keepAlive) let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) XCTAssertEqual(timers, [keepAliveTimer]) diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift new file mode 100644 index 00000000..b1e7f9a8 --- /dev/null +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -0,0 +1,66 @@ +@_spi(ConnectionPool) import PostgresNIO +import XCTest +import NIOPosix +import NIOSSL +import Logging +import Atomics + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class PostgresClientTests: XCTestCase { + + func testGetConnection() async throws { + var mlogger = Logger(label: "test") + mlogger.logLevel = .debug + let logger = mlogger + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8) + self.addTeardownBlock { + try await eventLoopGroup.shutdownGracefully() + } + + let clientConfig = PostgresClient.Configuration.makeTestConfiguration() + let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger) + + await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() + } + + for i in 0..<10000 { + taskGroup.addTask { + try await client.withConnection() { connection in + _ = try await connection.query("SELECT 1", logger: logger) + } + print("done: \(i)") + } + } + + for _ in 0..<10000 { + _ = await taskGroup.nextResult()! + } + + taskGroup.cancelAll() + } + } +} + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +extension PostgresClient.Configuration { + static func makeTestConfiguration() -> PostgresClient.Configuration { + var tlsConfiguration = TLSConfiguration.makeClientConfiguration() + tlsConfiguration.certificateVerification = .none + var clientConfig = PostgresClient.Configuration( + host: env("POSTGRES_HOSTNAME") ?? "localhost", + port: env("POSTGRES_PORT").flatMap({ Int($0) }) ?? 5432, + username: env("POSTGRES_USER") ?? "test_username", + password: env("POSTGRES_PASSWORD") ?? "test_password", + database: env("POSTGRES_DB") ?? "test_database", + tls: .prefer(tlsConfiguration) + ) + clientConfig.options.minimumConnections = 0 + clientConfig.options.maximumConnections = 12*4 + clientConfig.options.keepAliveBehavior = .init(frequency: .seconds(5)) + clientConfig.options.connectionIdleTimeout = .seconds(15) + + return clientConfig + } +} From 21473f547ab195da56dca4bd203d7d2f150c48c1 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 30 Oct 2023 14:32:31 +0100 Subject: [PATCH 207/292] Remove warn-concurrency warnings (#408) --- .../PostgresConnection+Configuration.swift | 10 ++--- .../Connection/PostgresConnection.swift | 17 +++++---- .../PostgresDatabase+PreparedQuery.swift | 35 ++++++++++++----- .../Message/PostgresMessage+Error.swift | 4 +- .../New/NotificationListener.swift | 3 +- Sources/PostgresNIO/New/PSQLError.swift | 3 +- Sources/PostgresNIO/New/PSQLRowStream.swift | 8 ++-- Sources/PostgresNIO/New/PSQLTask.swift | 2 +- .../New/PostgresChannelHandler.swift | 14 +++++-- Sources/PostgresNIO/New/PostgresCodable.swift | 2 +- .../PostgresNIO/PostgresDatabase+Query.swift | 28 +++++++++----- .../PostgresDatabase+SimpleQuery.swift | 12 ++++-- Sources/PostgresNIO/PostgresDatabase.swift | 5 ++- .../Utilities/PostgresJSONDecoder.swift | 16 +++++++- .../Utilities/PostgresJSONEncoder.swift | 16 +++++++- Tests/IntegrationTests/AsyncTests.swift | 2 +- .../PSQLIntegrationTests.swift | 11 +++--- .../New/Data/JSON+PSQLCodableTests.swift | 9 +++-- .../New/PSQLRowStreamTests.swift | 38 ++++++++++--------- .../New/PostgresConnectionTests.swift | 2 +- .../Utilities/PostgresJSONCodingTests.swift | 21 ++++++---- 21 files changed, 164 insertions(+), 94 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift index bc9bcfc2..22c59d8a 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift @@ -4,12 +4,12 @@ import NIOSSL extension PostgresConnection { /// A configuration object for a connection - public struct Configuration { - + public struct Configuration: Sendable { + // MARK: - TLS /// The possible modes of operation for TLS encapsulation of a connection. - public struct TLS { + public struct TLS: Sendable { // MARK: Initializers /// Do not try to create a TLS connection to the server. @@ -63,7 +63,7 @@ extension PostgresConnection { // MARK: - Connection options /// Describes options affecting how the underlying connection is made. - public struct Options { + public struct Options: Sendable { /// A timeout for connection attempts. Defaults to ten seconds. /// /// Ignored when using a preexisting communcation channel. (See @@ -219,7 +219,7 @@ extension PostgresConnection { /// the deprecated configuration. /// /// TODO: Drop with next major release - struct InternalConfiguration { + struct InternalConfiguration: Sendable { enum Connection { case unresolvedTCP(host: String, port: Int) case unresolvedUDS(path: String) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 9994ec42..f79a5555 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -144,8 +144,9 @@ public final class PostgresConnection: @unchecked Sendable { on eventLoop: EventLoop ) -> EventLoopFuture { - var logger = logger - logger[postgresMetadataKey: .connectionID] = "\(connectionID)" + var mlogger = logger + mlogger[postgresMetadataKey: .connectionID] = "\(connectionID)" + let logger = mlogger // Here we dispatch to the `eventLoop` first before we setup the EventLoopFuture chain, to // ensure all `flatMap`s are executed on the EventLoop (this means the enqueuing of the @@ -567,12 +568,13 @@ extension PostgresConnection { /// - line: The line, the query was started in. Used for better error reporting. /// - onRow: A closure that is invoked for every row. /// - Returns: An EventLoopFuture, that allows access to the future ``PostgresQueryMetadata``. + @preconcurrency public func query( _ query: PostgresQuery, logger: Logger, file: String = #fileID, line: Int = #line, - _ onRow: @escaping (PostgresRow) throws -> () + _ onRow: @escaping @Sendable (PostgresRow) throws -> () ) -> EventLoopFuture { self.queryStream(query, logger: logger).flatMap { rowStream in rowStream.onRow(onRow).flatMapThrowing { () -> PostgresQueryMetadata in @@ -638,6 +640,7 @@ extension PostgresConnection: PostgresDatabase { } } + @preconcurrency public func withConnection(_ closure: (PostgresConnection) -> EventLoopFuture) -> EventLoopFuture { closure(self) } @@ -645,11 +648,11 @@ extension PostgresConnection: PostgresDatabase { internal enum PostgresCommands: PostgresRequest { case query(PostgresQuery, - onMetadata: (PostgresQueryMetadata) -> () = { _ in }, - onRow: (PostgresRow) throws -> ()) - case queryAll(PostgresQuery, onResult: (PostgresQueryResult) -> ()) + onMetadata: @Sendable (PostgresQueryMetadata) -> () = { _ in }, + onRow: @Sendable (PostgresRow) throws -> ()) + case queryAll(PostgresQuery, onResult: @Sendable (PostgresQueryResult) -> ()) case prepareQuery(request: PrepareQueryRequest) - case executePreparedStatement(query: PreparedQuery, binds: [PostgresData], onRow: (PostgresRow) throws -> ()) + case executePreparedStatement(query: PreparedQuery, binds: [PostgresData], onRow: @Sendable (PostgresRow) throws -> ()) func respond(to message: PostgresMessage) throws -> [PostgresMessage]? { fatalError("This function must not be called") diff --git a/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift b/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift index 074ba6de..56496172 100644 --- a/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift +++ b/Sources/PostgresNIO/Connection/PostgresDatabase+PreparedQuery.swift @@ -1,4 +1,5 @@ import NIOCore +import NIOConcurrencyHelpers import struct Foundation.UUID extension PostgresDatabase { @@ -14,7 +15,8 @@ extension PostgresDatabase { } } - public func prepare(query: String, handler: @escaping (PreparedQuery) -> EventLoopFuture<[[PostgresRow]]>) -> EventLoopFuture<[[PostgresRow]]> { + @preconcurrency + public func prepare(query: String, handler: @Sendable @escaping (PreparedQuery) -> EventLoopFuture<[[PostgresRow]]>) -> EventLoopFuture<[[PostgresRow]]> { prepare(query: query) .flatMap { preparedQuery in handler(preparedQuery) @@ -26,7 +28,7 @@ extension PostgresDatabase { } -public struct PreparedQuery { +public struct PreparedQuery: Sendable { let underlying: PSQLPreparedStatement let database: PostgresDatabase @@ -36,11 +38,16 @@ public struct PreparedQuery { } public func execute(_ binds: [PostgresData] = []) -> EventLoopFuture<[PostgresRow]> { - var rows: [PostgresRow] = [] - return self.execute(binds) { rows.append($0) }.map { rows } + let rowsBoxed = NIOLockedValueBox([PostgresRow]()) + return self.execute(binds) { row in + rowsBoxed.withLockedValue { + $0.append(row) + } + }.map { rowsBoxed.withLockedValue { $0 } } } - public func execute(_ binds: [PostgresData] = [], _ onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { + @preconcurrency + public func execute(_ binds: [PostgresData] = [], _ onRow: @Sendable @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { let command = PostgresCommands.executePreparedStatement(query: self, binds: binds, onRow: onRow) return self.database.send(command, logger: self.database.logger) } @@ -50,15 +57,23 @@ public struct PreparedQuery { } } -final class PrepareQueryRequest { +final class PrepareQueryRequest: Sendable { let query: String let name: String - var prepared: PreparedQuery? = nil - - + var prepared: PreparedQuery? { + get { + self._prepared.withLockedValue { $0 } + } + set { + self._prepared.withLockedValue { + $0 = newValue + } + } + } + let _prepared: NIOLockedValueBox = .init(nil) + init(_ query: String, as name: String) { self.query = query self.name = name } - } diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Error.swift b/Sources/PostgresNIO/Message/PostgresMessage+Error.swift index 44f9e6bf..45cda21f 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Error.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Error.swift @@ -2,8 +2,8 @@ import NIOCore extension PostgresMessage { /// First message sent from the frontend during startup. - public struct Error: CustomStringConvertible { - public enum Field: UInt8, Hashable { + public struct Error: CustomStringConvertible, Sendable { + public enum Field: UInt8, Hashable, Sendable { /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), /// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a //// localized translation of one of these. Always present. diff --git a/Sources/PostgresNIO/New/NotificationListener.swift b/Sources/PostgresNIO/New/NotificationListener.swift index 5f4bc3de..9e47ff34 100644 --- a/Sources/PostgresNIO/New/NotificationListener.swift +++ b/Sources/PostgresNIO/New/NotificationListener.swift @@ -44,6 +44,7 @@ final class NotificationListener: @unchecked Sendable { func startListeningSucceeded(handler: PostgresChannelHandler) { self.eventLoop.preconditionInEventLoop() + let handlerLoopBound = NIOLoopBound(handler, eventLoop: self.eventLoop) switch self.state { case .streamInitialized(let checkedContinuation): @@ -55,7 +56,7 @@ final class NotificationListener: @unchecked Sendable { switch reason { case .cancelled: eventLoop.execute { - handler.cancelNotificationListener(channel: channel, id: listenerID) + handlerLoopBound.value.cancelNotificationListener(channel: channel, id: listenerID) } case .finished: diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index 81099043..4a9f9216 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -1,7 +1,8 @@ import NIOCore /// An error that is thrown from the PostgresClient. -public struct PSQLError: Error { +/// Sendability enforced through Copy on Write semantics +public struct PSQLError: Error, @unchecked Sendable { public struct Code: Sendable, Hashable, CustomStringConvertible { enum Base: Sendable, Hashable { diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index b008d185..b3dfea30 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -96,10 +96,8 @@ final class PSQLRowStream: @unchecked Sendable { let yieldResult = source.yield(contentsOf: bufferedRows) self.downstreamState = .asyncSequence(source, dataSource) - self.eventLoop.execute { - self.executeActionBasedOnYieldResult(yieldResult, source: dataSource) - } - + self.executeActionBasedOnYieldResult(yieldResult, source: dataSource) + case .finished(let buffer, let commandTag): _ = source.yield(contentsOf: buffer) source.finish() @@ -206,7 +204,7 @@ final class PSQLRowStream: @unchecked Sendable { // MARK: Consume on EventLoop - func onRow(_ onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { + func onRow(_ onRow: @Sendable @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { if self.eventLoop.inEventLoop { return self.onRow0(onRow) } else { diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 9425c12b..6308a5b3 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -70,7 +70,7 @@ final class ExtendedQueryContext { } } -final class PreparedStatementContext{ +final class PreparedStatementContext: Sendable { let name: String let sql: String let bindings: PostgresBindings diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 6d9d08b3..9d0ef2a5 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -597,8 +597,10 @@ final class PostgresChannelHandler: ChannelDuplexHandler { logger: self.logger, promise: promise ) + let loopBound = NIOLoopBound((self, context), eventLoop: self.eventLoop) promise.futureResult.whenComplete { result in - self.startListenCompleted(result, for: channel, context: context) + let (selfTransferred, context) = loopBound.value + selfTransferred.startListenCompleted(result, for: channel, context: context) } return .extendedQuery(query) @@ -643,8 +645,10 @@ final class PostgresChannelHandler: ChannelDuplexHandler { logger: self.logger, promise: promise ) + let loopBound = NIOLoopBound((self, context), eventLoop: self.eventLoop) promise.futureResult.whenComplete { result in - self.stopListenCompleted(result, for: channel, context: context) + let (selfTransferred, context) = loopBound.value + selfTransferred.stopListenCompleted(result, for: channel, context: context) } return .extendedQuery(query) @@ -693,10 +697,12 @@ final class PostgresChannelHandler: ChannelDuplexHandler { context: ChannelHandlerContext ) -> PSQLTask { let promise = self.eventLoop.makePromise(of: RowDescription?.self) + let loopBound = NIOLoopBound((self, context), eventLoop: self.eventLoop) promise.futureResult.whenComplete { result in + let (selfTransferred, context) = loopBound.value switch result { case .success(let rowDescription): - self.prepareStatementComplete( + selfTransferred.prepareStatementComplete( name: preparedStatement.name, rowDescription: rowDescription, context: context @@ -708,7 +714,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } else { psqlError = .connectionError(underlying: error) } - self.prepareStatementFailed( + selfTransferred.prepareStatementFailed( name: preparedStatement.name, error: psqlError, context: context diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift index 53dbd708..71c689bf 100644 --- a/Sources/PostgresNIO/New/PostgresCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -188,7 +188,7 @@ extension PostgresEncodingContext where JSONEncoder == Foundation.JSONEncoder { /// A context that is passed to Swift objects that are decoded from the Postgres wire format. Used /// to pass further information to the decoding method. -public struct PostgresDecodingContext { +public struct PostgresDecodingContext: Sendable { /// A ``PostgresJSONDecoder`` used to decode the object from json. public var jsonDecoder: JSONDecoder diff --git a/Sources/PostgresNIO/PostgresDatabase+Query.swift b/Sources/PostgresNIO/PostgresDatabase+Query.swift index 95abb6fc..01a7e61f 100644 --- a/Sources/PostgresNIO/PostgresDatabase+Query.swift +++ b/Sources/PostgresNIO/PostgresDatabase+Query.swift @@ -1,27 +1,35 @@ import NIOCore import Logging +import NIOConcurrencyHelpers extension PostgresDatabase { public func query( _ string: String, _ binds: [PostgresData] = [] ) -> EventLoopFuture { - var rows: [PostgresRow] = [] - var metadata: PostgresQueryMetadata? - return self.query(string, binds, onMetadata: { - metadata = $0 - }) { - rows.append($0) + let box = NIOLockedValueBox((metadata: PostgresQueryMetadata?.none, rows: [PostgresRow]())) + + return self.query(string, binds, onMetadata: { metadata in + box.withLockedValue { + $0.metadata = metadata + } + }) { row in + box.withLockedValue { + $0.rows.append(row) + } }.map { - .init(metadata: metadata!, rows: rows) + box.withLockedValue { + PostgresQueryResult(metadata: $0.metadata!, rows: $0.rows) + } } } + @preconcurrency public func query( _ string: String, _ binds: [PostgresData] = [], - onMetadata: @escaping (PostgresQueryMetadata) -> () = { _ in }, - onRow: @escaping (PostgresRow) throws -> () + onMetadata: @Sendable @escaping (PostgresQueryMetadata) -> () = { _ in }, + onRow: @Sendable @escaping (PostgresRow) throws -> () ) -> EventLoopFuture { var bindings = PostgresBindings(capacity: binds.count) binds.forEach { bindings.append($0) } @@ -58,7 +66,7 @@ extension PostgresQueryResult: Collection { } } -public struct PostgresQueryMetadata { +public struct PostgresQueryMetadata: Sendable { public let command: String public var oid: Int? public var rows: Int? diff --git a/Sources/PostgresNIO/PostgresDatabase+SimpleQuery.swift b/Sources/PostgresNIO/PostgresDatabase+SimpleQuery.swift index 77f3d034..5cf2d7a4 100644 --- a/Sources/PostgresNIO/PostgresDatabase+SimpleQuery.swift +++ b/Sources/PostgresNIO/PostgresDatabase+SimpleQuery.swift @@ -1,13 +1,19 @@ import NIOCore +import NIOConcurrencyHelpers import Logging extension PostgresDatabase { public func simpleQuery(_ string: String) -> EventLoopFuture<[PostgresRow]> { - var rows: [PostgresRow] = [] - return simpleQuery(string) { rows.append($0) }.map { rows } + let rowsBoxed = NIOLockedValueBox([PostgresRow]()) + return self.simpleQuery(string) { row in + rowsBoxed.withLockedValue { + $0.append(row) + } + }.map { rowsBoxed.withLockedValue { $0 } } } - public func simpleQuery(_ string: String, _ onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { + @preconcurrency + public func simpleQuery(_ string: String, _ onRow: @Sendable @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { self.query(string, onRow: onRow) } } diff --git a/Sources/PostgresNIO/PostgresDatabase.swift b/Sources/PostgresNIO/PostgresDatabase.swift index 64e44abb..fcd1afc7 100644 --- a/Sources/PostgresNIO/PostgresDatabase.swift +++ b/Sources/PostgresNIO/PostgresDatabase.swift @@ -1,14 +1,15 @@ import NIOCore import Logging -public protocol PostgresDatabase { +@preconcurrency +public protocol PostgresDatabase: Sendable { var logger: Logger { get } var eventLoop: EventLoop { get } func send( _ request: PostgresRequest, logger: Logger ) -> EventLoopFuture - + func withConnection(_ closure: @escaping (PostgresConnection) -> EventLoopFuture) -> EventLoopFuture } diff --git a/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift b/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift index fb7b4e8d..ba57ee9b 100644 --- a/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift +++ b/Sources/PostgresNIO/Utilities/PostgresJSONDecoder.swift @@ -2,11 +2,13 @@ import class Foundation.JSONDecoder import struct Foundation.Data import NIOFoundationCompat import NIOCore +import NIOConcurrencyHelpers /// A protocol that mimicks the Foundation `JSONDecoder.decode(_:from:)` function. /// Conform a non-Foundation JSON decoder to this protocol if you want PostgresNIO to be /// able to use it when decoding JSON & JSONB values (see `PostgresNIO._defaultJSONDecoder`) -public protocol PostgresJSONDecoder { +@preconcurrency +public protocol PostgresJSONDecoder: Sendable { func decode(_ type: T.Type, from data: Data) throws -> T where T : Decodable func decode(_ type: T.Type, from buffer: ByteBuffer) throws -> T @@ -20,10 +22,20 @@ extension PostgresJSONDecoder { } } +//@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) extension JSONDecoder: PostgresJSONDecoder {} +private let jsonDecoderLocked: NIOLockedValueBox = NIOLockedValueBox(JSONDecoder()) + /// The default JSON decoder used by PostgresNIO when decoding JSON & JSONB values. /// As `_defaultJSONDecoder` will be reused for decoding all JSON & JSONB values /// from potentially multiple threads at once, you must ensure your custom JSON decoder is /// thread safe internally like `Foundation.JSONDecoder`. -public var _defaultJSONDecoder: PostgresJSONDecoder = JSONDecoder() +public var _defaultJSONDecoder: PostgresJSONDecoder { + set { + jsonDecoderLocked.withLockedValue { $0 = newValue } + } + get { + jsonDecoderLocked.withLockedValue { $0 } + } +} diff --git a/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift b/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift index 735e4b14..9585f20b 100644 --- a/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift +++ b/Sources/PostgresNIO/Utilities/PostgresJSONEncoder.swift @@ -1,11 +1,13 @@ import Foundation import NIOFoundationCompat import NIOCore +import NIOConcurrencyHelpers /// A protocol that mimicks the Foundation `JSONEncoder.encode(_:)` function. /// Conform a non-Foundation JSON encoder to this protocol if you want PostgresNIO to be /// able to use it when encoding JSON & JSONB values (see `PostgresNIO._defaultJSONEncoder`) -public protocol PostgresJSONEncoder { +@preconcurrency +public protocol PostgresJSONEncoder: Sendable { func encode(_ value: T) throws -> Data where T : Encodable func encode(_ value: T, into buffer: inout ByteBuffer) throws @@ -20,8 +22,18 @@ extension PostgresJSONEncoder { extension JSONEncoder: PostgresJSONEncoder {} +private let jsonEncoderLocked: NIOLockedValueBox = NIOLockedValueBox(JSONEncoder()) + /// The default JSON encoder used by PostgresNIO when encoding JSON & JSONB values. /// As `_defaultJSONEncoder` will be reused for encoding all JSON & JSONB values /// from potentially multiple threads at once, you must ensure your custom JSON encoder is /// thread safe internally like `Foundation.JSONEncoder`. -public var _defaultJSONEncoder: PostgresJSONEncoder = JSONEncoder() +public var _defaultJSONEncoder: PostgresJSONEncoder { + set { + jsonEncoderLocked.withLockedValue { $0 = newValue } + } + get { + jsonEncoderLocked.withLockedValue { $0 } + } +} + diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 5c77ba29..91b5656c 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -323,7 +323,7 @@ final class AsyncPostgresConnectionTests: XCTestCase { let eventLoop = eventLoopGroup.next() struct TestPreparedStatement: PostgresPreparedStatement { - static var sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" + static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" typealias Row = (Int, String) var state: String diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 0550dc77..57939c06 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -1,3 +1,4 @@ +import Atomics import XCTest import Logging import PostgresNIO @@ -73,19 +74,17 @@ final class IntegrationTests: XCTestCase { defer { XCTAssertNoThrow(try conn?.close().wait()) } var metadata: PostgresQueryMetadata? - var received: Int64 = 0 + let received = ManagedAtomic(0) XCTAssertNoThrow(metadata = try conn?.query("SELECT generate_series(1, 10000);", logger: .psqlTest) { row in func workaround() { - var number: Int64? - XCTAssertNoThrow(number = try row.decode(Int64.self, context: .default)) - received += 1 - XCTAssertEqual(number, received) + let expected = received.wrappingIncrementThenLoad(ordering: .relaxed) + XCTAssertEqual(expected, try row.decode(Int64.self, context: .default)) } workaround() }.wait()) - XCTAssertEqual(received, 10000) + XCTAssertEqual(received.load(ordering: .relaxed), 10000) XCTAssertEqual(metadata?.command, "SELECT") XCTAssertEqual(metadata?.rows, 10000) } diff --git a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift index 858b6ede..52dead6a 100644 --- a/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/JSON+PSQLCodableTests.swift @@ -1,4 +1,5 @@ import XCTest +import Atomics import NIOCore @testable import PostgresNIO @@ -69,11 +70,11 @@ class JSON_PSQLCodableTests: XCTestCase { } func testCustomEncoderIsUsed() { - class TestEncoder: PostgresJSONEncoder { - var encodeHits = 0 + final class TestEncoder: PostgresJSONEncoder { + let encodeHits = ManagedAtomic(0) func encode(_ value: T, into buffer: inout ByteBuffer) throws where T : Encodable { - self.encodeHits += 1 + self.encodeHits.wrappingIncrement(ordering: .relaxed) } func encode(_ value: T) throws -> Data where T : Encodable { @@ -85,6 +86,6 @@ class JSON_PSQLCodableTests: XCTestCase { let encoder = TestEncoder() var buffer = ByteBuffer() XCTAssertNoThrow(try hello.encode(into: &buffer, context: .init(jsonEncoder: encoder))) - XCTAssertEqual(encoder.encodeHits, 1) + XCTAssertEqual(encoder.encodeHits.load(ordering: .relaxed), 1) } } diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift index d6d03107..9a1e9e41 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -1,3 +1,4 @@ +import Atomics import NIOCore import Logging import XCTest @@ -128,12 +129,12 @@ final class PSQLRowStreamTests: XCTestCase { XCTAssertEqual(dataSource.hitDemand, 0) // attach consumer - var counter = 0 + let counter = ManagedAtomic(0) let future = stream.onRow { row in - XCTAssertEqual(try row.decode(String.self, context: .default), "\(counter)") - counter += 1 + let expected = counter.loadThenWrappingIncrement(ordering: .relaxed) + XCTAssertEqual(try row.decode(String.self, context: .default), "\(expected)") } - XCTAssertEqual(counter, 2) + XCTAssertEqual(counter.load(ordering: .relaxed), 2) XCTAssertEqual(dataSource.hitDemand, 0) XCTAssertNoThrow(try future.wait()) @@ -155,7 +156,9 @@ final class PSQLRowStreamTests: XCTestCase { stream.receive([ [ByteBuffer(string: "0")], - [ByteBuffer(string: "1")] + [ByteBuffer(string: "1")], + [ByteBuffer(string: "2")], + [ByteBuffer(string: "3")], ]) stream.receive(completion: .success("SELECT 2")) @@ -163,15 +166,15 @@ final class PSQLRowStreamTests: XCTestCase { XCTAssertEqual(dataSource.hitDemand, 0) // attach consumer - var counter = 0 + let counter = ManagedAtomic(0) let future = stream.onRow { row in - XCTAssertEqual(try row.decode(String.self, context: .default), "\(counter)") - if counter == 1 { - throw OnRowError(row: counter) + let expected = counter.loadThenWrappingIncrement(ordering: .relaxed) + XCTAssertEqual(try row.decode(String.self, context: .default), "\(expected)") + if expected == 1 { + throw OnRowError(row: expected) } - counter += 1 } - XCTAssertEqual(counter, 1) + XCTAssertEqual(counter.load(ordering: .relaxed), 2) // one more than where we excited, because we already incremented XCTAssertEqual(dataSource.hitDemand, 0) XCTAssertThrowsError(try future.wait()) { @@ -179,7 +182,6 @@ final class PSQLRowStreamTests: XCTestCase { } } - func testOnRowBeforeStreamHasFinished() { let dataSource = CountingDataSource() let stream = PSQLRowStream( @@ -201,26 +203,26 @@ final class PSQLRowStreamTests: XCTestCase { XCTAssertEqual(dataSource.hitDemand, 0, "Before we have a consumer demand is not signaled") // attach consumer - var counter = 0 + let counter = ManagedAtomic(0) let future = stream.onRow { row in - XCTAssertEqual(try row.decode(String.self, context: .default), "\(counter)") - counter += 1 + let expected = counter.loadThenWrappingIncrement(ordering: .relaxed) + XCTAssertEqual(try row.decode(String.self, context: .default), "\(expected)") } - XCTAssertEqual(counter, 2) + XCTAssertEqual(counter.load(ordering: .relaxed), 2) XCTAssertEqual(dataSource.hitDemand, 1) stream.receive([ [ByteBuffer(string: "2")], [ByteBuffer(string: "3")] ]) - XCTAssertEqual(counter, 4) + XCTAssertEqual(counter.load(ordering: .relaxed), 4) XCTAssertEqual(dataSource.hitDemand, 2) stream.receive([ [ByteBuffer(string: "4")], [ByteBuffer(string: "5")] ]) - XCTAssertEqual(counter, 6) + XCTAssertEqual(counter.load(ordering: .relaxed), 6) XCTAssertEqual(dataSource.hitDemand, 3) stream.receive(completion: .success("SELECT 6")) diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 59917c40..3b1a8ca9 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -304,7 +304,7 @@ class PostgresConnectionTests: XCTestCase { } struct TestPrepareStatement: PostgresPreparedStatement { - static var sql = "SELECT datname FROM pg_stat_activity WHERE state = $1" + static let sql = "SELECT datname FROM pg_stat_activity WHERE state = $1" typealias Row = String var state: String diff --git a/Tests/PostgresNIOTests/Utilities/PostgresJSONCodingTests.swift b/Tests/PostgresNIOTests/Utilities/PostgresJSONCodingTests.swift index 2aad52b6..c6f876f2 100644 --- a/Tests/PostgresNIOTests/Utilities/PostgresJSONCodingTests.swift +++ b/Tests/PostgresNIOTests/Utilities/PostgresJSONCodingTests.swift @@ -1,3 +1,4 @@ +import Atomics import NIOCore import XCTest import PostgresNIO @@ -10,9 +11,9 @@ class PostgresJSONCodingTests: XCTestCase { PostgresNIO._defaultJSONEncoder = previousDefaultJSONEncoder } final class CustomJSONEncoder: PostgresJSONEncoder { - var didEncode = false + let counter = ManagedAtomic(0) func encode(_ value: T) throws -> Data where T : Encodable { - self.didEncode = true + self.counter.wrappingIncrement(ordering: .relaxed) return try JSONEncoder().encode(value) } } @@ -21,14 +22,16 @@ class PostgresJSONCodingTests: XCTestCase { var bar: Int } let customJSONEncoder = CustomJSONEncoder() + XCTAssertEqual(customJSONEncoder.counter.load(ordering: .relaxed), 0) PostgresNIO._defaultJSONEncoder = customJSONEncoder XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2))) - XCTAssert(customJSONEncoder.didEncode) + XCTAssertEqual(customJSONEncoder.counter.load(ordering: .relaxed), 1) let customJSONBEncoder = CustomJSONEncoder() + XCTAssertEqual(customJSONBEncoder.counter.load(ordering: .relaxed), 0) PostgresNIO._defaultJSONEncoder = customJSONBEncoder XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2))) - XCTAssert(customJSONBEncoder.didEncode) + XCTAssertEqual(customJSONBEncoder.counter.load(ordering: .relaxed), 1) } // https://github.com/vapor/postgres-nio/issues/126 @@ -38,9 +41,9 @@ class PostgresJSONCodingTests: XCTestCase { PostgresNIO._defaultJSONDecoder = previousDefaultJSONDecoder } final class CustomJSONDecoder: PostgresJSONDecoder { - var didDecode = false + let counter = ManagedAtomic(0) func decode(_ type: T.Type, from data: Data) throws -> T where T : Decodable { - self.didDecode = true + self.counter.wrappingIncrement(ordering: .relaxed) return try JSONDecoder().decode(type, from: data) } } @@ -49,13 +52,15 @@ class PostgresJSONCodingTests: XCTestCase { var bar: Int } let customJSONDecoder = CustomJSONDecoder() + XCTAssertEqual(customJSONDecoder.counter.load(ordering: .relaxed), 0) PostgresNIO._defaultJSONDecoder = customJSONDecoder XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2)).json(as: Object.self)) - XCTAssert(customJSONDecoder.didDecode) + XCTAssertEqual(customJSONDecoder.counter.load(ordering: .relaxed), 1) let customJSONBDecoder = CustomJSONDecoder() + XCTAssertEqual(customJSONBDecoder.counter.load(ordering: .relaxed), 0) PostgresNIO._defaultJSONDecoder = customJSONBDecoder XCTAssertNoThrow(try PostgresData(json: Object(foo: 1, bar: 2)).json(as: Object.self)) - XCTAssert(customJSONBDecoder.didDecode) + XCTAssertEqual(customJSONBDecoder.counter.load(ordering: .relaxed), 1) } } From c8269926eb3b705b70aff1975860e357760123c8 Mon Sep 17 00:00:00 2001 From: Tim Condon <0xTim@users.noreply.github.com> Date: Thu, 2 Nov 2023 12:48:52 +0000 Subject: [PATCH 208/292] Update README.md (#434) Point documentation links to our docs as that's where we host them now --- README.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 489d0e29..6f289673 100644 --- a/README.md +++ b/README.md @@ -176,20 +176,20 @@ Some queries do not receive any rows from the server (most often `INSERT`, `UPDA Please see [SECURITY.md] for details on the security process. [SSWG Incubation]: https://github.com/swift-server/sswg/blob/main/process/incubation.md#graduated-level -[Documentation]: https://swiftpackageindex.com/vapor/postgres-nio/documentation +[Documentation]: https://api.vapor.codes/postgresnio/documentation/postgresnio [Team Chat]: https://discord.gg/vapor [MIT License]: LICENSE [Continuous Integration]: https://github.com/vapor/postgres-nio/actions [Swift 5.7]: https://swift.org [Security.md]: https://github.com/vapor/.github/blob/main/SECURITY.md -[`PostgresConnection`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresconnection/ -[`query(_:logger:)`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresconnection/query(_:logger:file:line:)-9mkfn -[`PostgresQuery`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresquery/ -[`PostgresRow`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresrow/ -[`PostgresRowSequence`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresrowsequence/ -[`PostgresDecodable`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresdecodable/ -[`PostgresEncodable`]: https://swiftpackageindex.com/vapor/postgres-nio/documentation/postgresnio/postgresencodable/ +[`PostgresConnection`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresconnection +[`query(_:logger:)`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresconnection/query(_:logger:file:line:)-9mkfn +[`PostgresQuery`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresquery +[`PostgresRow`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresrow +[`PostgresRowSequence`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresrowsequence +[`PostgresDecodable`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresdecodable +[`PostgresEncodable`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresencodable [PostgresKit]: https://github.com/vapor/postgres-kit From 036931d968aab819f5e380a932237118ac4e87ba Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 10 Nov 2023 18:15:46 +0100 Subject: [PATCH 209/292] Fixes Crash in ConnectionPoolStateMachine (#438) - Correctly handle Connection closes while running a keep alive (fix: #436) - Add further keep alive tests - Restructure MockClock quite a bit --- .../PoolStateMachine+ConnectionGroup.swift | 14 +- .../PoolStateMachine+ConnectionState.swift | 19 ++- .../ConnectionPoolTests.swift | 158 +++++++++++++++++- .../Mocks/MockClock.swift | 77 ++++----- .../Mocks/MockConnection.swift | 89 ---------- .../Mocks/MockConnectionFactory.swift | 92 ++++++++++ .../Mocks/MockPingPongBehaviour.swift | 65 ++++++- ...oolStateMachine+ConnectionStateTests.swift | 2 +- 8 files changed, 356 insertions(+), 160 deletions(-) create mode 100644 Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift index b53f8d68..fabc3009 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -385,7 +385,8 @@ extension PoolStateMachine { @inlinable mutating func keepAliveSucceeded(_ connectionID: Connection.ID) -> (Int, AvailableConnectionContext)? { guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { - preconditionFailure("A connection that we don't know was released? Something is very wrong...") + // keepAliveSucceeded can race against, closeIfIdle, shutdowns or connection errors + return nil } guard let connectionInfo = self.connections[index].keepAliveSucceeded() else { @@ -430,15 +431,8 @@ extension PoolStateMachine { self.stats.idle -= 1 self.stats.closing += 1 - -// if idleState.runningKeepAlive { -// self.stats.runningKeepAlive -= 1 -// if self.keepAliveReducesAvailableStreams { -// self.stats.availableStreams += 1 -// } -// } - - self.stats.availableStreams -= closeAction.maxStreams + self.stats.runningKeepAlive -= closeAction.runningKeepAlive ? 1 : 0 + self.stats.availableStreams -= closeAction.maxStreams - closeAction.usedStreams return CloseAction( connection: closeAction.connection!, diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift index a56b87da..94196a09 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift @@ -496,6 +496,9 @@ extension PoolStateMachine { var usedStreams: UInt16 @usableFromInline var maxStreams: UInt16 + @usableFromInline + var runningKeepAlive: Bool + @inlinable init( @@ -503,13 +506,15 @@ extension PoolStateMachine { previousConnectionState: PreviousConnectionState, cancelTimers: Max2Sequence, usedStreams: UInt16, - maxStreams: UInt16 + maxStreams: UInt16, + runningKeepAlive: Bool ) { self.connection = connection self.previousConnectionState = previousConnectionState self.cancelTimers = cancelTimers self.usedStreams = usedStreams self.maxStreams = maxStreams + self.runningKeepAlive = runningKeepAlive } } @@ -526,7 +531,8 @@ extension PoolStateMachine { idleTimerState?.cancellationContinuation ), usedStreams: keepAlive.usedStreams, - maxStreams: maxStreams + maxStreams: maxStreams, + runningKeepAlive: keepAlive.isRunning ) case .leased, .closed: @@ -559,7 +565,8 @@ extension PoolStateMachine { idleTimerState?.cancellationContinuation ), usedStreams: keepAlive.usedStreams, - maxStreams: maxStreams + maxStreams: maxStreams, + runningKeepAlive: keepAlive.isRunning ) case .leased(let connection, usedStreams: let usedStreams, maxStreams: let maxStreams, var keepAlive): @@ -571,7 +578,8 @@ extension PoolStateMachine { keepAlive.cancelTimerIfScheduled() ), usedStreams: keepAlive.usedStreams + usedStreams, - maxStreams: maxStreams + maxStreams: maxStreams, + runningKeepAlive: keepAlive.isRunning ) case .backingOff(let timer): @@ -581,7 +589,8 @@ extension PoolStateMachine { previousConnectionState: .backingOff, cancelTimers: Max2Sequence(timer.cancellationContinuation), usedStreams: 0, - maxStreams: 0 + maxStreams: 0, + runningKeepAlive: false ) } } diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index 5be12a1c..57980711 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -16,7 +16,7 @@ final class ConnectionPoolTests: XCTestCase { configuration: config, idGenerator: ConnectionIDGenerator(), requestType: ConnectionRequest.self, - keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil), + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), clock: ContinuousClock() ) { @@ -74,7 +74,7 @@ final class ConnectionPoolTests: XCTestCase { configuration: config, idGenerator: ConnectionIDGenerator(), requestType: ConnectionRequest.self, - keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil), + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), clock: clock ) { @@ -119,7 +119,7 @@ final class ConnectionPoolTests: XCTestCase { configuration: config, idGenerator: ConnectionIDGenerator(), requestType: ConnectionRequest.self, - keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil), + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), clock: clock ) { @@ -135,7 +135,7 @@ final class ConnectionPoolTests: XCTestCase { throw ConnectionCreationError() } - await clock.timerScheduled() + await clock.nextTimerScheduled() taskGroup.cancelAll() } @@ -156,7 +156,7 @@ final class ConnectionPoolTests: XCTestCase { configuration: config, idGenerator: ConnectionIDGenerator(), requestType: ConnectionRequest.self, - keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil), + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), clock: ContinuousClock() ) { @@ -220,6 +220,154 @@ final class ConnectionPoolTests: XCTestCase { XCTAssert(hasFinished.load(ordering: .relaxed)) XCTAssertEqual(factory.runningConnections.count, 0) } + + func testKeepAliveWorks() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 1 + mutableConfig.maximumConnectionHardLimit = 1 + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + async let lease1ConnectionAsync = pool.leaseConnection() + + let connection = await factory.nextConnectAttempt { connectionID in + return 1 + } + + let lease1Connection = try await lease1ConnectionAsync + XCTAssert(connection === lease1Connection) + + pool.releaseConnection(lease1Connection) + + // keep alive 1 + + // validate that a keep alive timer and an idle timeout timer is scheduled + var expectedInstants: Set = [.init(keepAliveDuration), .init(config.idleTimeout)] + let deadline1 = await clock.nextTimerScheduled() + print(deadline1) + XCTAssertNotNil(expectedInstants.remove(deadline1)) + let deadline2 = await clock.nextTimerScheduled() + print(deadline2) + XCTAssertNotNil(expectedInstants.remove(deadline2)) + XCTAssert(expectedInstants.isEmpty) + + // move clock forward to keep alive + let newTime = clock.now.advanced(by: keepAliveDuration) + clock.advance(to: newTime) + print("clock advanced to: \(newTime)") + + await keepAlive.nextKeepAlive { keepAliveConnection in + defer { print("keep alive 1 has run") } + XCTAssertTrue(keepAliveConnection === lease1Connection) + return true + } + + // keep alive 2 + + let deadline3 = await clock.nextTimerScheduled() + XCTAssertEqual(deadline3, clock.now.advanced(by: keepAliveDuration)) + print(deadline3) + + // race keep alive vs timeout + clock.advance(to: clock.now.advanced(by: keepAliveDuration)) + + taskGroup.cancelAll() + + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } + + func testKeepAliveWorksRacesAgainstShutdown() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 1 + mutableConfig.maximumConnectionHardLimit = 1 + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + async let lease1ConnectionAsync = pool.leaseConnection() + + let connection = await factory.nextConnectAttempt { connectionID in + return 1 + } + + let lease1Connection = try await lease1ConnectionAsync + XCTAssert(connection === lease1Connection) + + pool.releaseConnection(lease1Connection) + + // keep alive 1 + + // validate that a keep alive timer and an idle timeout timer is scheduled + var expectedInstants: Set = [.init(keepAliveDuration), .init(config.idleTimeout)] + let deadline1 = await clock.nextTimerScheduled() + print(deadline1) + XCTAssertNotNil(expectedInstants.remove(deadline1)) + let deadline2 = await clock.nextTimerScheduled() + print(deadline2) + XCTAssertNotNil(expectedInstants.remove(deadline2)) + XCTAssert(expectedInstants.isEmpty) + + clock.advance(to: clock.now.advanced(by: keepAliveDuration)) + + await keepAlive.nextKeepAlive { keepAliveConnection in + defer { print("keep alive 1 has run") } + XCTAssertTrue(keepAliveConnection === lease1Connection) + return true + } + + taskGroup.cancelAll() + print("cancelled") + + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } + } diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift index 573ff073..cd08d54e 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift @@ -1,5 +1,6 @@ @testable import _ConnectionPoolModule import Atomics +import DequeModule @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) final class MockClock: Clock { @@ -34,19 +35,19 @@ final class MockClock: Clock { var sleepersHeap: Array - var waitersHeap: Array + var waiters: Deque + var nextDeadlines: Deque init() { self.now = .init(.seconds(0)) self.sleepersHeap = Array() - self.waitersHeap = Array() + self.waiters = Deque() + self.nextDeadlines = Deque() } } private struct Waiter { - var expectedSleepers: Int - - var continuation: CheckedContinuation + var continuation: CheckedContinuation } private struct Sleeper { @@ -77,39 +78,34 @@ final class MockClock: Clock { case cancel } - let action = self.stateBox.withLockedValue { state -> (SleepAction, ArraySlice) in - state.waitersHeap = state.waitersHeap.map { waiter in - var waiter = waiter; waiter.expectedSleepers -= 1; return waiter - } - let slice: ArraySlice - let lastRemainingIndex = state.waitersHeap.firstIndex(where: { $0.expectedSleepers > 0 }) - if let lastRemainingIndex { - slice = state.waitersHeap[0.. (SleepAction, Waiter?) in + let waiter: Waiter? + if let next = state.waiters.popFirst() { + waiter = next } else { - slice = [] + state.nextDeadlines.append(deadline) + waiter = nil } if Task.isCancelled { - return (.cancel, slice) + return (.cancel, waiter) } if state.now >= deadline { - return (.resume, slice) + return (.resume, waiter) } - let newWaiter = Sleeper(id: waiterID, deadline: deadline, continuation: continuation) + let newSleeper = Sleeper(id: waiterID, deadline: deadline, continuation: continuation) if let index = state.sleepersHeap.lastIndex(where: { $0.deadline < deadline }) { - state.sleepersHeap.insert(newWaiter, at: index + 1) + state.sleepersHeap.insert(newSleeper, at: index + 1) + } else if let first = state.sleepersHeap.first, first.deadline > deadline { + state.sleepersHeap.insert(newSleeper, at: 0) } else { - state.sleepersHeap.append(newWaiter) + state.sleepersHeap.append(newSleeper) } - return (.none, slice) + return (.none, waiter) } switch action.0 { @@ -121,9 +117,7 @@ final class MockClock: Clock { break } - for waiter in action.1 { - waiter.continuation.resume() - } + action.1?.continuation.resume(returning: deadline) } } onCancel: { let continuation = self.stateBox.withLockedValue { state -> CheckedContinuation? in @@ -136,28 +130,21 @@ final class MockClock: Clock { } } - func timerScheduled(n: Int = 1) async { - precondition(n >= 1, "At least one new sleep must be awaited") - await withCheckedContinuation { (continuation: CheckedContinuation<(), Never>) in - let result = self.stateBox.withLockedValue { state -> Bool in - let n = n - state.sleepersHeap.count - - if n <= 0 { - return true - } - - let waiter = Waiter(expectedSleepers: n, continuation: continuation) - - if let index = state.waitersHeap.firstIndex(where: { $0.expectedSleepers > n }) { - state.waitersHeap.insert(waiter, at: index) + @discardableResult + func nextTimerScheduled() async -> Instant { + await withCheckedContinuation { (continuation: CheckedContinuation) in + let instant = self.stateBox.withLockedValue { state -> Instant? in + if let scheduled = state.nextDeadlines.popFirst() { + return scheduled } else { - state.waitersHeap.append(waiter) + let waiter = Waiter(continuation: continuation) + state.waiters.append(waiter) + return nil } - return false } - if result { - continuation.resume() + if let instant { + continuation.resume(returning: instant) } } } diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift index 0fa382f7..49bcc23a 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift @@ -73,92 +73,3 @@ final class MockConnection: PooledConnection, @unchecked Sendable { } } -@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class MockConnectionFactory where Clock.Duration == Duration { - typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator - typealias Request = ConnectionRequest - typealias KeepAliveBehavior = MockPingPongBehavior - typealias MetricsDelegate = NoOpConnectionPoolMetrics - typealias ConnectionID = Int - typealias Connection = MockConnection - - let stateBox = NIOLockedValueBox(State()) - - struct State { - var attempts = Deque<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)>() - - var waiter = Deque), Never>>() - - var runningConnections = [ConnectionID: Connection]() - } - - var pendingConnectionAttemptsCount: Int { - self.stateBox.withLockedValue { $0.attempts.count } - } - - var runningConnections: [Connection] { - self.stateBox.withLockedValue { Array($0.runningConnections.values) } - } - - func makeConnection( - id: Int, - for pool: ConnectionPool, Int, MockPingPongBehavior, NoOpConnectionPoolMetrics, Clock> - ) async throws -> ConnectionAndMetadata { - // we currently don't support cancellation when creating a connection - let result = try await withCheckedThrowingContinuation { (checkedContinuation: CheckedContinuation<(MockConnection, UInt16), any Error>) in - let waiter = self.stateBox.withLockedValue { state -> (CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>)? in - if let waiter = state.waiter.popFirst() { - return waiter - } else { - state.attempts.append((id, checkedContinuation)) - return nil - } - } - - if let waiter { - waiter.resume(returning: (id, checkedContinuation)) - } - } - - return .init(connection: result.0, maximalStreamsOnConnection: result.1) - } - - @discardableResult - func nextConnectAttempt(_ closure: (ConnectionID) async throws -> UInt16) async rethrows -> Connection { - let (connectionID, continuation) = await withCheckedContinuation { (continuation: CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>) in - let attempt = self.stateBox.withLockedValue { state -> (ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)? in - if let attempt = state.attempts.popFirst() { - return attempt - } else { - state.waiter.append(continuation) - return nil - } - } - - if let attempt { - continuation.resume(returning: attempt) - } - } - - do { - let streamCount = try await closure(connectionID) - let connection = MockConnection(id: connectionID) - - connection.onClose { _ in - self.stateBox.withLockedValue { state in - _ = state.runningConnections.removeValue(forKey: connectionID) - } - } - - self.stateBox.withLockedValue { state in - _ = state.runningConnections[connectionID] = connection - } - - continuation.resume(returning: (connection, streamCount)) - return connection - } catch { - continuation.resume(throwing: error) - throw error - } - } -} diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift new file mode 100644 index 00000000..b0c94467 --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift @@ -0,0 +1,92 @@ +@testable import _ConnectionPoolModule +import DequeModule + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +final class MockConnectionFactory where Clock.Duration == Duration { + typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator + typealias Request = ConnectionRequest + typealias KeepAliveBehavior = MockPingPongBehavior + typealias MetricsDelegate = NoOpConnectionPoolMetrics + typealias ConnectionID = Int + typealias Connection = MockConnection + + let stateBox = NIOLockedValueBox(State()) + + struct State { + var attempts = Deque<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)>() + + var waiter = Deque), Never>>() + + var runningConnections = [ConnectionID: Connection]() + } + + var pendingConnectionAttemptsCount: Int { + self.stateBox.withLockedValue { $0.attempts.count } + } + + var runningConnections: [Connection] { + self.stateBox.withLockedValue { Array($0.runningConnections.values) } + } + + func makeConnection( + id: Int, + for pool: ConnectionPool, Int, MockPingPongBehavior, NoOpConnectionPoolMetrics, Clock> + ) async throws -> ConnectionAndMetadata { + // we currently don't support cancellation when creating a connection + let result = try await withCheckedThrowingContinuation { (checkedContinuation: CheckedContinuation<(MockConnection, UInt16), any Error>) in + let waiter = self.stateBox.withLockedValue { state -> (CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>)? in + if let waiter = state.waiter.popFirst() { + return waiter + } else { + state.attempts.append((id, checkedContinuation)) + return nil + } + } + + if let waiter { + waiter.resume(returning: (id, checkedContinuation)) + } + } + + return .init(connection: result.0, maximalStreamsOnConnection: result.1) + } + + @discardableResult + func nextConnectAttempt(_ closure: (ConnectionID) async throws -> UInt16) async rethrows -> Connection { + let (connectionID, continuation) = await withCheckedContinuation { (continuation: CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>) in + let attempt = self.stateBox.withLockedValue { state -> (ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)? in + if let attempt = state.attempts.popFirst() { + return attempt + } else { + state.waiter.append(continuation) + return nil + } + } + + if let attempt { + continuation.resume(returning: attempt) + } + } + + do { + let streamCount = try await closure(connectionID) + let connection = MockConnection(id: connectionID) + + connection.onClose { _ in + self.stateBox.withLockedValue { state in + _ = state.runningConnections.removeValue(forKey: connectionID) + } + } + + self.stateBox.withLockedValue { state in + _ = state.runningConnections[connectionID] = connection + } + + continuation.resume(returning: (connection, streamCount)) + return connection + } catch { + continuation.resume(throwing: error) + throw error + } + } +} diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift index 2ee9b7a0..637f096c 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift @@ -1,14 +1,69 @@ -import _ConnectionPoolModule +@testable import _ConnectionPoolModule +import DequeModule @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -struct MockPingPongBehavior: ConnectionKeepAliveBehavior { +final class MockPingPongBehavior: ConnectionKeepAliveBehavior { let keepAliveFrequency: Duration? - init(keepAliveFrequency: Duration?) { + let stateBox = NIOLockedValueBox(State()) + + struct State { + var runs = Deque<(Connection, CheckedContinuation)>() + + var waiter = Deque), Never>>() + } + + init(keepAliveFrequency: Duration?, connectionType: Connection.Type) { self.keepAliveFrequency = keepAliveFrequency } - func runKeepAlive(for connection: MockConnection) async throws { - preconditionFailure() + func runKeepAlive(for connection: Connection) async throws { + precondition(self.keepAliveFrequency != nil) + + // we currently don't support cancellation when creating a connection + let success = try await withCheckedThrowingContinuation { (checkedContinuation: CheckedContinuation) -> () in + let waiter = self.stateBox.withLockedValue { state -> (CheckedContinuation<(Connection, CheckedContinuation), Never>)? in + if let waiter = state.waiter.popFirst() { + return waiter + } else { + state.runs.append((connection, checkedContinuation)) + return nil + } + } + + if let waiter { + waiter.resume(returning: (connection, checkedContinuation)) + } + } + + precondition(success) + } + + @discardableResult + func nextKeepAlive(_ closure: (Connection) async throws -> Bool) async rethrows -> Connection { + let (connection, continuation) = await withCheckedContinuation { (continuation: CheckedContinuation<(Connection, CheckedContinuation), Never>) in + let run = self.stateBox.withLockedValue { state -> (Connection, CheckedContinuation)? in + if let run = state.runs.popFirst() { + return run + } else { + state.waiter.append(continuation) + return nil + } + } + + if let run { + continuation.resume(returning: run) + } + } + + do { + let success = try await closure(connection) + + continuation.resume(returning: success) + return connection + } catch { + continuation.resume(throwing: error) + throw error + } } } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift index 7751837e..bc4c2c4b 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift @@ -257,7 +257,7 @@ final class PoolStateMachine_ConnectionStateTests: XCTestCase { XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) XCTAssertNil(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken)) - XCTAssertEqual(state.closeIfIdle(), .init(connection: connection, previousConnectionState: .idle, cancelTimers: [keepAliveTimerCancellationToken, idleTimerCancellationToken], usedStreams: 0, maxStreams: 1)) + XCTAssertEqual(state.closeIfIdle(), .init(connection: connection, previousConnectionState: .idle, cancelTimers: [keepAliveTimerCancellationToken, idleTimerCancellationToken], usedStreams: 0, maxStreams: 1, runningKeepAlive: false)) XCTAssertEqual(state.runKeepAliveIfIdle(reducesAvailableStreams: true), .none) } From c41f7e217e09c51a4453019b2875ecb82b69df3d Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Fri, 10 Nov 2023 12:33:30 -0600 Subject: [PATCH 210/292] Update README.md --- README.md | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 6f289673..ef1dc4ec 100644 --- a/README.md +++ b/README.md @@ -7,22 +7,22 @@

- Documentation + Documentation - MIT License + MIT License - Continuous Integration + Continuous Integration - Swift 5.7 - 5.9 + Swift 5.7 + - SSWG Incubation Level: Graduated + SSWG Incubation Level: Graduated

-
+ 🐘 Non-blocking, event-driven Swift client for PostgreSQL built on [SwiftNIO]. Features: @@ -190,9 +190,7 @@ Please see [SECURITY.md] for details on the security process. [`PostgresRowSequence`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresrowsequence [`PostgresDecodable`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresdecodable [`PostgresEncodable`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresencodable - -[PostgresKit]: https://github.com/vapor/postgres-kit - [SwiftNIO]: https://github.com/apple/swift-nio +[PostgresKit]: https://github.com/vapor/postgres-kit [SwiftLog]: https://github.com/apple/swift-log [`Logger`]: https://apple.github.io/swift-log/docs/current/Logging/Structs/Logger.html From f0bfba793eb626cda98e456a7f1f2c1ef13a983a Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Fri, 10 Nov 2023 12:34:36 -0600 Subject: [PATCH 211/292] Temporarily disable nightly/main CI --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index cc34ddcd..fe4aa185 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,7 +22,7 @@ jobs: - swift:5.8-jammy - swift:5.9-jammy - swiftlang/swift:nightly-5.10-jammy - - swiftlang/swift:nightly-main-jammy + #- swiftlang/swift:nightly-main-jammy include: - swift-image: swift:5.9-jammy code-coverage: true From d5d16e3230cc1d86dde3fd9e8266422d27a440b6 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Sun, 12 Nov 2023 12:17:09 +0100 Subject: [PATCH 212/292] Test cancel connection request (#439) --- .../ConnectionPoolTests.swift | 60 +++++++++- .../Utils/Waiter.swift | 109 ++++++++++++++++++ 2 files changed, 168 insertions(+), 1 deletion(-) create mode 100644 Tests/ConnectionPoolModuleTests/Utils/Waiter.swift diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index 57980711..4d4cac95 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -368,6 +368,64 @@ final class ConnectionPoolTests: XCTestCase { } } -} + func testCancelConnectionRequestWorks() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 4 + mutableConfig.maximumConnectionHardLimit = 4 + mutableConfig.idleTimeout = .seconds(10) + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + let leaseTask = Task { + _ = try await pool.leaseConnection() + } + + let connectionAttemptWaiter = Waiter(of: Void.self) + + taskGroup.addTask { + try await factory.nextConnectAttempt { connectionID in + connectionAttemptWaiter.yield(value: ()) + throw CancellationError() + } + } + + try await connectionAttemptWaiter.result + leaseTask.cancel() + + let taskResult = await leaseTask.result + switch taskResult { + case .success: + XCTFail("Expected task failure") + case .failure(let failure): + XCTAssertEqual(failure as? ConnectionPoolError, .requestCancelled) + } + + taskGroup.cancelAll() + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } +} diff --git a/Tests/ConnectionPoolModuleTests/Utils/Waiter.swift b/Tests/ConnectionPoolModuleTests/Utils/Waiter.swift new file mode 100644 index 00000000..12cf90cc --- /dev/null +++ b/Tests/ConnectionPoolModuleTests/Utils/Waiter.swift @@ -0,0 +1,109 @@ +import Atomics +@testable import _ConnectionPoolModule + +final class Waiter: Sendable { + struct State: Sendable { + + var result: Swift.Result? = nil + var continuations: [(Int, CheckedContinuation)] = [] + + } + + let waiterID = ManagedAtomic(0) + let stateBox: NIOLockedValueBox = NIOLockedValueBox(State()) + + init(of: Result.Type) {} + + enum GetAction { + case fail(any Error) + case succeed(Result) + case none + } + + var result: Result { + get async throws { + let waiterID = self.waiterID.loadThenWrappingIncrement(ordering: .relaxed) + + return try await withTaskCancellationHandler { + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let action = self.stateBox.withLockedValue { state -> GetAction in + if Task.isCancelled { + return .fail(CancellationError()) + } + + switch state.result { + case .none: + state.continuations.append((waiterID, continuation)) + return .none + + case .success(let result): + return .succeed(result) + + case .failure(let error): + return .fail(error) + } + } + + switch action { + case .fail(let error): + continuation.resume(throwing: error) + + case .succeed(let result): + continuation.resume(returning: result) + + case .none: + break + } + } + } onCancel: { + let cont = self.stateBox.withLockedValue { state -> CheckedContinuation? in + guard state.result == nil else { return nil } + + guard let contIndex = state.continuations.firstIndex(where: { $0.0 == waiterID }) else { + return nil + } + let (_, continuation) = state.continuations.remove(at: contIndex) + return continuation + } + + cont?.resume(throwing: CancellationError()) + } + } + } + + func yield(value: Result) { + let continuations = self.stateBox.withLockedValue { state in + guard state.result == nil else { + return [(Int, CheckedContinuation)]().lazy.map(\.1) + } + state.result = .success(value) + + let continuations = state.continuations + state.continuations = [] + + return continuations.lazy.map(\.1) + } + + for continuation in continuations { + continuation.resume(returning: value) + } + } + + func yield(error: any Error) { + let continuations = self.stateBox.withLockedValue { state in + guard state.result == nil else { + return [(Int, CheckedContinuation)]().lazy.map(\.1) + } + state.result = .failure(error) + + let continuations = state.continuations + state.continuations = [] + + return continuations.lazy.map(\.1) + } + + for continuation in continuations { + continuation.resume(throwing: error) + } + } +} From e1781633a8a843b8901ab8b71cdfdf80fad690af Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 13 Nov 2023 11:13:29 +0100 Subject: [PATCH 213/292] Add test to lease multiple connections at once (#440) - Add test to lease multiple connections at once - Rename `Waiter` to `Future` - Rename `Waiter.Result` to `Future.Success` --- .../ConnectionPoolTests.swift | 86 ++++++++++++++++++- .../Mocks/MockConnectionFactory.swift | 2 +- .../Utils/{Waiter.swift => Future.swift} | 25 +++--- 3 files changed, 99 insertions(+), 14 deletions(-) rename Tests/ConnectionPoolModuleTests/Utils/{Waiter.swift => Future.swift} (77%) diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index 4d4cac95..a4c2cde7 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -401,7 +401,7 @@ final class ConnectionPoolTests: XCTestCase { _ = try await pool.leaseConnection() } - let connectionAttemptWaiter = Waiter(of: Void.self) + let connectionAttemptWaiter = Future(of: Void.self) taskGroup.addTask { try await factory.nextConnectAttempt { connectionID in @@ -410,7 +410,7 @@ final class ConnectionPoolTests: XCTestCase { } } - try await connectionAttemptWaiter.result + try await connectionAttemptWaiter.success leaseTask.cancel() let taskResult = await leaseTask.result @@ -427,5 +427,87 @@ final class ConnectionPoolTests: XCTestCase { } } } + + func testLeasingMultipleConnectionsAtOnceWorks() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 4 + mutableConfig.maximumConnectionSoftLimit = 4 + mutableConfig.maximumConnectionHardLimit = 4 + mutableConfig.idleTimeout = .seconds(10) + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionFuture.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + // create 4 persisted connections + for _ in 0..<4 { + await factory.nextConnectAttempt { connectionID in + return 1 + } + } + + // create 4 connection requests + let requests = (0..<4).map { ConnectionFuture(id: $0) } + + // lease 4 connections at once + pool.leaseConnections(requests) + var connections = [MockConnection]() + + for request in requests { + let connection = try await request.future.success + connections.append(connection) + } + + // Ensure that we got 4 distinct connections + XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 4) + + // release all 4 leased connections + for connection in connections { + pool.releaseConnection(connection) + } + + // shutdown + taskGroup.cancelAll() + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } } +struct ConnectionFuture: ConnectionRequestProtocol { + let id: Int + let future: Future + + init(id: Int) { + self.id = id + self.future = Future(of: MockConnection.self) + } + + func complete(with result: Result) { + switch result { + case .success(let success): + self.future.yield(value: success) + case .failure(let failure): + self.future.yield(error: failure) + } + } +} diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift index b0c94467..eec2e7c3 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift @@ -30,7 +30,7 @@ final class MockConnectionFactory where Clock.Duratio func makeConnection( id: Int, - for pool: ConnectionPool, Int, MockPingPongBehavior, NoOpConnectionPoolMetrics, Clock> + for pool: ConnectionPool, NoOpConnectionPoolMetrics, Clock> ) async throws -> ConnectionAndMetadata { // we currently don't support cancellation when creating a connection let result = try await withCheckedThrowingContinuation { (checkedContinuation: CheckedContinuation<(MockConnection, UInt16), any Error>) in diff --git a/Tests/ConnectionPoolModuleTests/Utils/Waiter.swift b/Tests/ConnectionPoolModuleTests/Utils/Future.swift similarity index 77% rename from Tests/ConnectionPoolModuleTests/Utils/Waiter.swift rename to Tests/ConnectionPoolModuleTests/Utils/Future.swift index 12cf90cc..2bee3216 100644 --- a/Tests/ConnectionPoolModuleTests/Utils/Waiter.swift +++ b/Tests/ConnectionPoolModuleTests/Utils/Future.swift @@ -1,31 +1,34 @@ import Atomics @testable import _ConnectionPoolModule -final class Waiter: Sendable { +/// This is a `Future` type that shall make writing tests a bit simpler. I'm well aware, that this is a pattern +/// that should not be embraced with structured concurrency. However writing all tests in full structured +/// concurrency is an effort, that isn't worth the endgoals in my view. +final class Future: Sendable { struct State: Sendable { - var result: Swift.Result? = nil - var continuations: [(Int, CheckedContinuation)] = [] + var result: Swift.Result? = nil + var continuations: [(Int, CheckedContinuation)] = [] } let waiterID = ManagedAtomic(0) let stateBox: NIOLockedValueBox = NIOLockedValueBox(State()) - init(of: Result.Type) {} + init(of: Success.Type) {} enum GetAction { case fail(any Error) - case succeed(Result) + case succeed(Success) case none } - var result: Result { + var success: Success { get async throws { let waiterID = self.waiterID.loadThenWrappingIncrement(ordering: .relaxed) return try await withTaskCancellationHandler { - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in let action = self.stateBox.withLockedValue { state -> GetAction in if Task.isCancelled { return .fail(CancellationError()) @@ -56,7 +59,7 @@ final class Waiter: Sendable { } } } onCancel: { - let cont = self.stateBox.withLockedValue { state -> CheckedContinuation? in + let cont = self.stateBox.withLockedValue { state -> CheckedContinuation? in guard state.result == nil else { return nil } guard let contIndex = state.continuations.firstIndex(where: { $0.0 == waiterID }) else { @@ -71,10 +74,10 @@ final class Waiter: Sendable { } } - func yield(value: Result) { + func yield(value: Success) { let continuations = self.stateBox.withLockedValue { state in guard state.result == nil else { - return [(Int, CheckedContinuation)]().lazy.map(\.1) + return [(Int, CheckedContinuation)]().lazy.map(\.1) } state.result = .success(value) @@ -92,7 +95,7 @@ final class Waiter: Sendable { func yield(error: any Error) { let continuations = self.stateBox.withLockedValue { state in guard state.result == nil else { - return [(Int, CheckedContinuation)]().lazy.map(\.1) + return [(Int, CheckedContinuation)]().lazy.map(\.1) } state.result = .failure(error) From dc94503944f5f0a6b244efacd0ceb92d1e52cdb8 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 14 Nov 2023 10:12:42 +0100 Subject: [PATCH 214/292] Add Test: Lease connection after shutdown has started fails (#441) --- .../ConnectionPoolTests.swift | 116 ++++++++++++++++++ .../Mocks/MockConnection.swift | 66 +++++++--- 2 files changed, 165 insertions(+), 17 deletions(-) diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index a4c2cde7..d4388893 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -491,6 +491,122 @@ final class ConnectionPoolTests: XCTestCase { } } } + + func testLeasingConnectionAfterShutdownIsInvokedFails() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 4 + mutableConfig.maximumConnectionSoftLimit = 4 + mutableConfig.maximumConnectionHardLimit = 4 + mutableConfig.idleTimeout = .seconds(10) + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + // create 4 persisted connections + for _ in 0..<4 { + await factory.nextConnectAttempt { connectionID in + return 1 + } + } + + // shutdown + taskGroup.cancelAll() + + do { + _ = try await pool.leaseConnection() + XCTFail("Expected a failure") + } catch { + print("failed") + XCTAssertEqual(error as? ConnectionPoolError, .poolShutdown) + } + + print("will close connections: \(factory.runningConnections)") + for connection in factory.runningConnections { + try await connection.signalToClose + connection.closeIfClosing() + } + } + } + + func testLeasingConnectionsAfterShutdownIsInvokedFails() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 4 + mutableConfig.maximumConnectionSoftLimit = 4 + mutableConfig.maximumConnectionHardLimit = 4 + mutableConfig.idleTimeout = .seconds(10) + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionFuture.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + // create 4 persisted connections + for _ in 0..<4 { + await factory.nextConnectAttempt { connectionID in + return 1 + } + } + + // shutdown + taskGroup.cancelAll() + + // create 4 connection requests + let requests = (0..<4).map { ConnectionFuture(id: $0) } + + // lease 4 connections at once + pool.leaseConnections(requests) + + for request in requests { + do { + _ = try await request.future.success + XCTFail("Expected a failure") + } catch { + XCTAssertEqual(error as? ConnectionPoolError, .poolShutdown) + } + } + + for connection in factory.runningConnections { + try await connection.signalToClose + connection.closeIfClosing() + } + } + } } struct ConnectionFuture: ConnectionRequestProtocol { diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift index 49bcc23a..f826ea04 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift @@ -2,38 +2,59 @@ import DequeModule @testable import _ConnectionPoolModule // Sendability enforced through the lock -final class MockConnection: PooledConnection, @unchecked Sendable { +final class MockConnection: PooledConnection, Sendable { typealias ID = Int let id: ID private enum State { - case running([@Sendable ((any Error)?) -> ()]) + case running([CheckedContinuation], [@Sendable ((any Error)?) -> ()]) case closing([@Sendable ((any Error)?) -> ()]) case closed } - private let lock = NIOLock() - private var _state = State.running([]) + private let lock: NIOLockedValueBox = NIOLockedValueBox(.running([], [])) init(id: Int) { self.id = id } + var signalToClose: Void { + get async throws { + try await withCheckedThrowingContinuation { continuation in + let runRightAway = self.lock.withLockedValue { state -> Bool in + switch state { + case .running(var continuations, let callbacks): + continuations.append(continuation) + state = .running(continuations, callbacks) + return false + + case .closing, .closed: + return true + } + } + + if runRightAway { + continuation.resume() + } + } + } + } + func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) { - let enqueued = self.lock.withLock { () -> Bool in - switch self._state { + let enqueued = self.lock.withLockedValue { state -> Bool in + switch state { case .closed: return false - case .running(var callbacks): + case .running(let continuations, var callbacks): callbacks.append(closure) - self._state = .running(callbacks) + state = .running(continuations, callbacks) return true case .closing(var callbacks): callbacks.append(closure) - self._state = .closing(callbacks) + state = .closing(callbacks) return true } } @@ -44,25 +65,30 @@ final class MockConnection: PooledConnection, @unchecked Sendable { } func close() { - self.lock.withLock { - switch self._state { - case .running(let callbacks): - self._state = .closing(callbacks) + let continuations = self.lock.withLockedValue { state -> [CheckedContinuation] in + switch state { + case .running(let continuations, let callbacks): + state = .closing(callbacks) + return continuations case .closing, .closed: - break + return [] } } + + for continuation in continuations { + continuation.resume() + } } func closeIfClosing() { - let callbacks = self.lock.withLock { () -> [@Sendable ((any Error)?) -> ()] in - switch self._state { + let callbacks = self.lock.withLockedValue { state -> [@Sendable ((any Error)?) -> ()] in + switch state { case .running, .closed: return [] case .closing(let callbacks): - self._state = .closed + state = .closed return callbacks } } @@ -73,3 +99,9 @@ final class MockConnection: PooledConnection, @unchecked Sendable { } } +extension MockConnection: CustomStringConvertible { + var description: String { + let state = self.lock.withLockedValue { $0 } + return "MockConnection(id: \(self.id), state: \(state))" + } +} From 54f491c9b9a1d0a4f099d21a473b630bcc89d551 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 14 Nov 2023 15:02:23 +0100 Subject: [PATCH 215/292] Add support for multiple streams (#442) --- .../ConnectionPoolModule/ConnectionPool.swift | 6 +- .../PoolStateMachine+ConnectionGroup.swift | 48 +++++- .../PoolStateMachine+ConnectionState.swift | 47 ++++++ .../PoolStateMachine.swift | 46 +++++- .../ConnectionPoolTests.swift | 142 ++++++++++++++++++ 5 files changed, 280 insertions(+), 9 deletions(-) diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index e9c9c4c9..ec865979 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -265,8 +265,10 @@ public final class ConnectionPool< } - public func connection(_ connection: Connection, didReceiveNewMaxStreamSetting: UInt16) { - + public func connectionReceivedNewMaxStreamSetting(_ connection: Connection, newMaxStreamSetting maxStreams: UInt16) { + self.modifyStateAndRunActions { state in + state.stateMachine.connectionReceivedNewMaxStreamSetting(connection.id, newMaxStreamSetting: maxStreams) + } } public func run() async { diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift index fabc3009..0dbca86f 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -256,6 +256,50 @@ extension PoolStateMachine { return self.connections[index].timerScheduled(timer, cancelContinuation: cancelContinuation) } + // MARK: Changes at runtime + + @usableFromInline + struct NewMaxStreamInfo { + + @usableFromInline + var index: Int + + @usableFromInline + var newMaxStreams: UInt16 + + @usableFromInline + var oldMaxStreams: UInt16 + + @usableFromInline + var usedStreams: UInt16 + + @inlinable + init(index: Int, info: ConnectionState.NewMaxStreamInfo) { + self.index = index + self.newMaxStreams = info.newMaxStreams + self.oldMaxStreams = info.oldMaxStreams + self.usedStreams = info.usedStreams + } + } + + @inlinable + mutating func connectionReceivedNewMaxStreamSetting( + _ connectionID: ConnectionID, + newMaxStreamSetting maxStreams: UInt16 + ) -> NewMaxStreamInfo? { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + return nil + } + + guard let info = self.connections[index].newMaxStreamSetting(maxStreams) else { + return nil + } + + self.stats.availableStreams += maxStreams - info.oldMaxStreams + + return NewMaxStreamInfo(index: index, info: info) + } + // MARK: Leasing and releasing /// Lease a connection, if an idle connection is available. @@ -424,9 +468,9 @@ extension PoolStateMachine { /// Closes the connection at the given index. @inlinable - mutating func closeConnectionIfIdle(at index: Int) -> CloseAction { + mutating func closeConnectionIfIdle(at index: Int) -> CloseAction? { guard let closeAction = self.connections[index].closeIfIdle() else { - preconditionFailure("Invalid state: \(self)") + return nil // apparently the connection isn't idle } self.stats.idle -= 1 diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift index 94196a09..98755ff9 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift @@ -195,6 +195,53 @@ extension PoolStateMachine { } } + @usableFromInline + struct NewMaxStreamInfo { + @usableFromInline + var newMaxStreams: UInt16 + + @usableFromInline + var oldMaxStreams: UInt16 + + @usableFromInline + var usedStreams: UInt16 + + @inlinable + init(newMaxStreams: UInt16, oldMaxStreams: UInt16, usedStreams: UInt16) { + self.newMaxStreams = newMaxStreams + self.oldMaxStreams = oldMaxStreams + self.usedStreams = usedStreams + } + } + + @inlinable + mutating func newMaxStreamSetting(_ newMaxStreams: UInt16) -> NewMaxStreamInfo? { + switch self.state { + case .starting, .backingOff: + preconditionFailure("Invalid state: \(self.state)") + + case .idle(let connection, let oldMaxStreams, let keepAlive, idleTimer: let idleTimer): + self.state = .idle(connection, maxStreams: newMaxStreams, keepAlive: keepAlive, idleTimer: idleTimer) + return NewMaxStreamInfo( + newMaxStreams: newMaxStreams, + oldMaxStreams: oldMaxStreams, + usedStreams: keepAlive.usedStreams + ) + + case .leased(let connection, let usedStreams, let oldMaxStreams, let keepAlive): + self.state = .leased(connection, usedStreams: usedStreams, maxStreams: newMaxStreams, keepAlive: keepAlive) + return NewMaxStreamInfo( + newMaxStreams: newMaxStreams, + oldMaxStreams: oldMaxStreams, + usedStreams: usedStreams + keepAlive.usedStreams + ) + + case .closing, .closed: + return nil + } + } + + @inlinable mutating func parkConnection(scheduleKeepAliveTimer: Bool, scheduleIdleTimeoutTimer: Bool) -> Max2Sequence { var keepAliveTimer: ConnectionTimer? diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index 4484e405..6671460a 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -262,6 +262,39 @@ struct PoolStateMachine< } } + @inlinable + mutating func connectionReceivedNewMaxStreamSetting( + _ connection: ConnectionID, + newMaxStreamSetting maxStreams: UInt16 + ) -> Action { + guard let info = self.connections.connectionReceivedNewMaxStreamSetting(connection, newMaxStreamSetting: maxStreams) else { + return .none() + } + + let waitingRequests = self.requestQueue.count + + guard waitingRequests > 0 else { + return .none() + } + + // the only thing we can do if we receive a new max stream setting is check if the new stream + // setting is higher and then dequeue some waiting requests + + guard info.newMaxStreams > info.oldMaxStreams && info.newMaxStreams > info.usedStreams else { + return .none() + } + + let leaseStreams = min(info.newMaxStreams - info.oldMaxStreams, info.newMaxStreams - info.usedStreams, UInt16(clamping: waitingRequests)) + let requests = self.requestQueue.pop(max: leaseStreams) + precondition(Int(leaseStreams) == requests.count) + let leaseResult = self.connections.leaseConnection(at: info.index, streams: leaseStreams) + + return .init( + request: .leaseConnection(requests, leaseResult.connection), + connection: .cancelTimers(.init(leaseResult.timersToCancel)) + ) + } + @inlinable mutating func timerScheduled(_ timer: Timer, cancelContinuation: TimerCancellationToken) -> TimerCancellationToken? { self.connections.timerScheduled(timer.underlying, cancelContinuation: cancelContinuation) @@ -445,11 +478,14 @@ struct PoolStateMachine< } case .overflow: - let closeAction = self.connections.closeConnectionIfIdle(at: index) - return .init( - request: .none, - connection: .closeConnection(closeAction.connection, closeAction.timersToCancel) - ) + if let closeAction = self.connections.closeConnectionIfIdle(at: index) { + return .init( + request: .none, + connection: .closeConnection(closeAction.connection, closeAction.timersToCancel) + ) + } else { + return .none() + } } } diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index d4388893..0ff2bdf7 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -607,6 +607,148 @@ final class ConnectionPoolTests: XCTestCase { } } } + + func testLeasingMultipleStreamsFromOneConnectionWorks() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 1 + mutableConfig.maximumConnectionHardLimit = 10 + mutableConfig.idleTimeout = .seconds(10) + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionFuture.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + // create 4 connection requests + let requests = (0..<10).map { ConnectionFuture(id: $0) } + pool.leaseConnections(requests) + var connections = [MockConnection]() + + await factory.nextConnectAttempt { connectionID in + return 10 + } + + for request in requests { + let connection = try await request.future.success + connections.append(connection) + } + + // Ensure that all requests got the same connection + XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 1) + + // release all 10 leased streams + for connection in connections { + pool.releaseConnection(connection) + } + + for _ in 0..<9 { + _ = try? await factory.nextConnectAttempt { connectionID in + throw CancellationError() + } + } + + // shutdown + taskGroup.cancelAll() + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } + + func testIncreasingAvailableStreamsWorks() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(30) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 1 + mutableConfig.maximumConnectionHardLimit = 1 + mutableConfig.idleTimeout = .seconds(10) + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionFuture.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + // create 4 connection requests + var requests = (0..<21).map { ConnectionFuture(id: $0) } + pool.leaseConnections(requests) + var connections = [MockConnection]() + + await factory.nextConnectAttempt { connectionID in + return 1 + } + + let connection = try await requests.first!.future.success + connections.append(connection) + requests.removeFirst() + + pool.connectionReceivedNewMaxStreamSetting(connection, newMaxStreamSetting: 21) + + for (index, request) in requests.enumerated() { + let connection = try await request.future.success + connections.append(connection) + } + + // Ensure that all requests got the same connection + XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 1) + + requests = (22..<42).map { ConnectionFuture(id: $0) } + pool.leaseConnections(requests) + + // release all 21 leased streams in a single call + pool.releaseConnection(connection, streams: 21) + + // ensure all 20 new requests got fulfilled + for request in requests { + let connection = try await request.future.success + connections.append(connection) + } + + // release all 20 leased streams one by one + for _ in requests { + pool.releaseConnection(connection, streams: 1) + } + + // shutdown + taskGroup.cancelAll() + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } } struct ConnectionFuture: ConnectionRequestProtocol { From e60e49507411fbf187fcf9f74a4596d68f3651c9 Mon Sep 17 00:00:00 2001 From: Timo <38291523+lovetodream@users.noreply.github.com> Date: Tue, 12 Dec 2023 16:28:16 +0100 Subject: [PATCH 216/292] Fix crash in PoolStateMachine+ConnectionGroup when closing connection while keepAlive is running (#444) Fixes #443. Co-authored-by: Gwynne Raskind Co-authored-by: Fabian Fett --- .github/workflows/test.yml | 16 ++- .../ConnectionPoolModule/ConnectionPool.swift | 2 +- .../PoolStateMachine+ConnectionGroup.swift | 24 ++++ .../PoolStateMachine+ConnectionState.swift | 5 + .../PoolStateMachine.swift | 9 ++ .../ConnectionPoolTests.swift | 86 ++++++++++++++ ...oolStateMachine+ConnectionGroupTests.swift | 31 +++++ .../PoolStateMachineTests.swift | 111 ++++++++++++++++++ 8 files changed, 278 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fe4aa185..3d1f44a4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,7 +22,7 @@ jobs: - swift:5.8-jammy - swift:5.9-jammy - swiftlang/swift:nightly-5.10-jammy - #- swiftlang/swift:nightly-main-jammy + - swiftlang/swift:nightly-main-jammy include: - swift-image: swift:5.9-jammy code-coverage: true @@ -133,7 +133,7 @@ jobs: matrix: postgres-formula: # Only test one version on macOS, let Linux do the rest - - postgresql@15 + - postgresql@16 postgres-auth: # Only test one auth method on macOS, Linux tests will cover the others - scram-sha-256 @@ -157,10 +157,16 @@ jobs: - name: Install Postgres, setup DB and auth, and wait for server start run: | export PATH="$(brew --prefix)/opt/${POSTGRES_FORMULA}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test - (brew unlink postgresql || true) && brew install "${POSTGRES_FORMULA}" && brew link --force "${POSTGRES_FORMULA}" + # ** BEGIN ** Work around bug in both Homebrew and GHA + (brew upgrade python@3.11 || true) && (brew link --force --overwrite python@3.11 || true) + (brew upgrade python@3.12 || true) && (brew link --force --overwrite python@3.12 || true) + brew upgrade + # ** END ** Work around bug in both Homebrew and GHA + brew install --overwrite "${POSTGRES_FORMULA}" + brew link --overwrite --force "${POSTGRES_FORMULA}" initdb --locale=C --auth-host "${POSTGRES_AUTH_METHOD}" -U "${POSTGRES_USER}" --pwfile=<(echo "${POSTGRES_PASSWORD}") pg_ctl start --wait - timeout-minutes: 2 + timeout-minutes: 15 - name: Checkout code uses: actions/checkout@v4 - name: Run all tests @@ -183,7 +189,7 @@ jobs: gh-codeql: runs-on: ubuntu-latest - container: swift:5.8-jammy # CodeQL currently broken with 5.9 + container: swift:5.9-jammy permissions: { actions: write, contents: read, security-events: write } steps: - name: Check out code diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index ec865979..c20fa59e 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -481,7 +481,7 @@ public final class ConnectionPool< self.observabilityDelegate.keepAliveFailed(id: connection.id, error: error) self.modifyStateAndRunActions { state in - state.stateMachine.connectionClosed(connection) + state.stateMachine.connectionKeepAliveFailed(connection.id) } } } diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift index 0dbca86f..833365fa 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -449,6 +449,30 @@ extension PoolStateMachine { return (index, context) } + @inlinable + mutating func keepAliveFailed(_ connectionID: Connection.ID) -> CloseAction? { + guard let index = self.connections.firstIndex(where: { $0.id == connectionID }) else { + // Connection has already been closed + return nil + } + + guard let closeAction = self.connections[index].keepAliveFailed() else { + return nil + } + + self.stats.idle -= 1 + self.stats.closing += 1 + self.stats.runningKeepAlive -= closeAction.runningKeepAlive ? 1 : 0 + self.stats.availableStreams -= closeAction.maxStreams - closeAction.usedStreams + + // force unwrapping the connection is fine, because a close action due to failed + // keepAlive cannot happen without a connection + return CloseAction( + connection: closeAction.connection!, + timersToCancel: closeAction.cancelTimers + ) + } + // MARK: Connection close/removal @usableFromInline diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift index 98755ff9..2fb68a2d 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift @@ -455,6 +455,11 @@ extension PoolStateMachine { } } + @inlinable + mutating func keepAliveFailed() -> CloseAction? { + return self.close() + } + @inlinable mutating func timerScheduled( _ timer: ConnectionTimer, diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index 6671460a..3b996033 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -374,6 +374,15 @@ struct PoolStateMachine< return self.handleAvailableConnection(index: index, availableContext: context) } + @inlinable + mutating func connectionKeepAliveFailed(_ connectionID: ConnectionID) -> Action { + guard let closeAction = self.connections.keepAliveFailed(connectionID) else { + return .none() + } + + return .init(request: .none, connection: .closeConnection(closeAction.connection, closeAction.timersToCancel)) + } + @inlinable mutating func connectionIdleTimerTriggered(_ connectionID: ConnectionID) -> Action { precondition(self.requestQueue.isEmpty) diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index 0ff2bdf7..ba3c6a3f 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -300,6 +300,92 @@ final class ConnectionPoolTests: XCTestCase { } } + func testKeepAliveOnClose() async throws { + let clock = MockClock() + let factory = MockConnectionFactory() + let keepAliveDuration = Duration.seconds(20) + let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self) + + var mutableConfig = ConnectionPoolConfiguration() + mutableConfig.minimumConnectionCount = 0 + mutableConfig.maximumConnectionSoftLimit = 1 + mutableConfig.maximumConnectionHardLimit = 1 + let config = mutableConfig + + let pool = ConnectionPool( + configuration: config, + idGenerator: ConnectionIDGenerator(), + requestType: ConnectionRequest.self, + keepAliveBehavior: keepAlive, + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await pool.run() + } + + async let lease1ConnectionAsync = pool.leaseConnection() + + let connection = await factory.nextConnectAttempt { connectionID in + return 1 + } + + let lease1Connection = try await lease1ConnectionAsync + XCTAssert(connection === lease1Connection) + + pool.releaseConnection(lease1Connection) + + // keep alive 1 + + // validate that a keep alive timer and an idle timeout timer is scheduled + var expectedInstants: Set = [.init(keepAliveDuration), .init(config.idleTimeout)] + let deadline1 = await clock.nextTimerScheduled() + print(deadline1) + XCTAssertNotNil(expectedInstants.remove(deadline1)) + let deadline2 = await clock.nextTimerScheduled() + print(deadline2) + XCTAssertNotNil(expectedInstants.remove(deadline2)) + XCTAssert(expectedInstants.isEmpty) + + // move clock forward to keep alive + let newTime = clock.now.advanced(by: keepAliveDuration) + clock.advance(to: newTime) + + await keepAlive.nextKeepAlive { keepAliveConnection in + XCTAssertTrue(keepAliveConnection === lease1Connection) + return true + } + + // keep alive 2 + let deadline3 = await clock.nextTimerScheduled() + XCTAssertEqual(deadline3, clock.now.advanced(by: keepAliveDuration)) + clock.advance(to: clock.now.advanced(by: keepAliveDuration)) + + let failingKeepAliveDidRun = ManagedAtomic(false) + // the following keep alive should not cause a crash + _ = try? await keepAlive.nextKeepAlive { keepAliveConnection in + defer { + XCTAssertFalse(failingKeepAliveDidRun + .compareExchange(expected: false, desired: true, ordering: .relaxed).original) + } + XCTAssertTrue(keepAliveConnection === lease1Connection) + keepAliveConnection.close() + throw CancellationError() // any error + } // will fail and it's expected + XCTAssertTrue(failingKeepAliveDidRun.load(ordering: .relaxed)) + + taskGroup.cancelAll() + + for connection in factory.runningConnections { + connection.closeIfClosing() + } + } + } + func testKeepAliveWorksRacesAgainstShutdown() async throws { let clock = MockClock() let factory = MockConnectionFactory() diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift index ac0f96f4..6b8d6c6e 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift @@ -293,4 +293,35 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { XCTAssertEqual(afterPingIdleContext.use, .persisted) XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) } + + func testKeepAliveShouldNotIndicateCloseConnectionAfterClosed() { + var connections = TestPoolStateMachine.ConnectionGroup( + generator: self.idGenerator, + minimumConcurrentConnections: 0, + maximumConcurrentConnectionSoftLimit: 2, + maximumConcurrentConnectionHardLimit: 2, + keepAlive: true, + keepAliveReducesAvailableStreams: true + ) + + guard let firstRequest = connections.createNewDemandConnectionIfPossible() else { return XCTFail("Expected to have a request here") } + + let newConnection = MockConnection(id: firstRequest.connectionID) + let (connectionIndex, establishedConnectionContext) = connections.newConnectionEstablished(newConnection, maxStreams: 1) + XCTAssertEqual(establishedConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) + XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + _ = connections.parkConnection(at: connectionIndex, hasBecomeIdle: true) + let keepAliveTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: firstRequest.connectionID, usecase: .keepAlive) + let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) + XCTAssertNil(connections.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + let keepAliveAction = connections.keepAliveIfIdle(newConnection.id) + XCTAssertEqual(keepAliveAction, .init(connection: newConnection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken)) + XCTAssertEqual(connections.stats, .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) + + _ = connections.closeConnectionIfIdle(newConnection.id) + guard connections.keepAliveFailed(newConnection.id) == nil else { + return XCTFail("Expected keepAliveFailed not to cause close again") + } + XCTAssertEqual(connections.stats, .init(closing: 1)) + } } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift index a19d2326..f5ada14f 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -266,4 +266,115 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(releaseRequest1.connection, .none) } + func testKeepAliveOnClosingConnection() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 0 + configuration.maximumConnectionSoftLimit = 2 + configuration.maximumConnectionHardLimit = 2 + configuration.keepAliveDuration = .seconds(2) + configuration.idleTimeoutDuration = .seconds(4) + + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + // don't refill pool + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 0) + + // request connection while none exists + let request1 = MockRequest() + let leaseRequest1 = stateMachine.leaseConnection(request1) + XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) + XCTAssertEqual(leaseRequest1.request, .none) + + // make connection 1 + let connection1 = MockConnection(id: 0) + let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) + XCTAssertEqual(createdAction1.request, .leaseConnection(.init(element: request1), connection1)) + XCTAssertEqual(createdAction1.connection, .none) + _ = stateMachine.releaseConnection(connection1, streams: 1) + + // trigger keep alive + let keepAliveAction1 = stateMachine.connectionKeepAliveTimerTriggered(connection1.id) + XCTAssertEqual(keepAliveAction1.connection, .runKeepAlive(connection1, nil)) + + // fail keep alive and cause closed + let keepAliveFailed1 = stateMachine.connectionKeepAliveFailed(connection1.id) + XCTAssertEqual(keepAliveFailed1.connection, .closeConnection(connection1, [])) + connection1.closeIfClosing() + + // request connection while none exists anymore + let request2 = MockRequest() + let leaseRequest2 = stateMachine.leaseConnection(request2) + XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) + XCTAssertEqual(leaseRequest2.request, .none) + + // make connection 2 + let connection2 = MockConnection(id: 1) + let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) + XCTAssertEqual(createdAction2.request, .leaseConnection(.init(element: request2), connection2)) + XCTAssertEqual(createdAction2.connection, .none) + _ = stateMachine.releaseConnection(connection2, streams: 1) + + // trigger keep alive while connection is still open + let keepAliveAction2 = stateMachine.connectionKeepAliveTimerTriggered(connection2.id) + XCTAssertEqual(keepAliveAction2.connection, .runKeepAlive(connection2, nil)) + + // close connection in the middle of keep alive + connection2.close() + connection2.closeIfClosing() + + // fail keep alive and cause closed + let keepAliveFailed2 = stateMachine.connectionKeepAliveFailed(connection2.id) + XCTAssertEqual(keepAliveFailed2.connection, .closeConnection(connection2, [])) + } + + func testConnectionIsEstablishedAfterFailedKeepAliveIfNotEnoughConnectionsLeft() { + var configuration = PoolConfiguration() + configuration.minimumConnectionCount = 1 + configuration.maximumConnectionSoftLimit = 2 + configuration.maximumConnectionHardLimit = 2 + configuration.keepAliveDuration = .seconds(2) + configuration.idleTimeoutDuration = .seconds(4) + + + var stateMachine = TestPoolStateMachine( + configuration: configuration, + generator: .init(), + timerCancellationTokenType: MockTimerCancellationToken.self + ) + + // refill pool + let requests = stateMachine.refillConnections() + XCTAssertEqual(requests.count, 1) + + // one connection should exist + let request = MockRequest() + let leaseRequest = stateMachine.leaseConnection(request) + XCTAssertEqual(leaseRequest.connection, .none) + XCTAssertEqual(leaseRequest.request, .none) + + // make connection 1 + let connection = MockConnection(id: 0) + let createdAction = stateMachine.connectionEstablished(connection, maxStreams: 1) + XCTAssertEqual(createdAction.request, .leaseConnection(.init(element: request), connection)) + XCTAssertEqual(createdAction.connection, .none) + _ = stateMachine.releaseConnection(connection, streams: 1) + + // trigger keep alive + let keepAliveAction = stateMachine.connectionKeepAliveTimerTriggered(connection.id) + XCTAssertEqual(keepAliveAction.connection, .runKeepAlive(connection, nil)) + + // fail keep alive, cause closed and make new connection + let keepAliveFailed = stateMachine.connectionKeepAliveFailed(connection.id) + XCTAssertEqual(keepAliveFailed.connection, .closeConnection(connection, [])) + let connectionClosed = stateMachine.connectionClosed(connection) + XCTAssertEqual(connectionClosed.connection, .makeConnection(.init(connectionID: 1), [])) + connection.closeIfClosing() + } + } From fa3137d39bca84843739db1c5a3db2d7f4ae65e6 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 12 Dec 2023 17:01:12 +0100 Subject: [PATCH 217/292] Support additional connection parameters (#361) --- .../PostgresConnection+Configuration.swift | 7 +++- .../ConnectionStateMachine.swift | 34 ++++++++++++--- .../New/PostgresChannelHandler.swift | 2 +- .../New/PostgresFrontendMessageEncoder.swift | 9 +++- .../PSQLFrontendMessageDecoder.swift | 11 +++-- .../Extensions/PostgresFrontendMessage.swift | 27 ++++++++++-- .../New/Messages/StartupTests.swift | 41 ++++++++++++++++++- .../New/PostgresChannelHandlerTests.swift | 9 ++-- .../New/PostgresConnectionTests.swift | 2 +- 9 files changed, 117 insertions(+), 25 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift index 22c59d8a..dd0f5404 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift @@ -85,7 +85,11 @@ extension PostgresConnection { /// This property is provided for compatibility with Amazon RDS Proxy, which requires it to be `false`. /// If you are not using Amazon RDS Proxy, you should leave this set to `true` (the default). public var requireBackendKeyData: Bool - + + /// Additional parameters to send to the server on startup. The name value pairs are added to the initial + /// startup message that the client sends to the server. + public var additionalStartupParameters: [(String, String)] + /// Create an options structure with default values. /// /// Most users should not need to adjust the defaults. @@ -93,6 +97,7 @@ extension PostgresConnection { self.connectTimeout = .seconds(10) self.tlsServerName = nil self.requireBackendKeyData = true + self.additionalStartupParameters = [] } } diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 9cde0cf3..d7a609a6 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -1113,11 +1113,19 @@ struct SendPrepareStatement { let query: String } -struct AuthContext: Equatable, CustomDebugStringConvertible { - let username: String - let password: String? - let database: String? - +struct AuthContext: CustomDebugStringConvertible { + var username: String + var password: String? + var database: String? + var additionalParameters: [(String, String)] + + init(username: String, password: String? = nil, database: String? = nil, additionalParameters: [(String, String)] = []) { + self.username = username + self.password = password + self.database = database + self.additionalParameters = additionalParameters + } + var debugDescription: String { """ AuthContext(username: \(String(reflecting: self.username)), \ @@ -1127,6 +1135,22 @@ struct AuthContext: Equatable, CustomDebugStringConvertible { } } +extension AuthContext: Equatable { + static func ==(lhs: Self, rhs: Self) -> Bool { + guard lhs.username == rhs.username + && lhs.password == rhs.password + && lhs.database == rhs.database + && lhs.additionalParameters.count == rhs.additionalParameters.count + else { + return false + } + + return lhs.additionalParameters.elementsEqual(rhs.additionalParameters) { lhs, rhs in + lhs.0 == rhs.0 && lhs.1 == rhs.1 + } + } +} + enum PasswordAuthencationMode: Equatable { case cleartext case md5(salt: UInt32) diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 9d0ef2a5..54ae0fc9 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -328,7 +328,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { case .wait: break case .sendStartupMessage(let authContext): - self.encoder.startup(user: authContext.username, database: authContext.database) + self.encoder.startup(user: authContext.username, database: authContext.database, options: authContext.additionalParameters) context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .sendSSLRequest: self.encoder.ssl() diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift index e98ab1f1..97805418 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift @@ -25,7 +25,7 @@ struct PostgresFrontendMessageEncoder { self.buffer = buffer } - mutating func startup(user: String, database: String?) { + mutating func startup(user: String, database: String?, options: [(String, String)]) { self.clearIfNeeded() self.buffer.psqlLengthPrefixed { buffer in buffer.writeInteger(Self.startupVersionThree) @@ -37,6 +37,13 @@ struct PostgresFrontendMessageEncoder { buffer.writeNullTerminatedString(database) } + // we don't send replication parameters, as the default is false and this is what we + // need for a client + for (key, value) in options { + buffer.writeNullTerminatedString(key) + buffer.writeNullTerminatedString(value) + } + buffer.writeInteger(UInt8(0)) } } diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index 46c043b1..55ccd0a9 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -39,8 +39,8 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { case 196608: var user: String? var database: String? - var options: String? - + var options = [(String, String)]() + while let name = messageSlice.readNullTerminatedString(), messageSlice.readerIndex < finalIndex { let value = messageSlice.readNullTerminatedString() @@ -51,11 +51,10 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { case "database": database = value - case "options": - options = value - default: - break + if let value = value { + options.append((name, value)) + } } } diff --git a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift index 010667dc..2532959a 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift @@ -103,7 +103,7 @@ enum PostgresFrontendMessage: Equatable { static let requestCode: Int32 = 80877103 } - struct Startup: Hashable { + struct Startup: Equatable { static let versionThree: Int32 = 0x00_03_00_00 /// Creates a `Startup` with "3.0" as the protocol version. @@ -119,7 +119,7 @@ enum PostgresFrontendMessage: Equatable { /// The protocol version number is followed by one or more pairs of parameter /// name and value strings. A zero byte is required as a terminator after /// the last name/value pair. `user` is required, others are optional. - struct Parameters: Hashable { + struct Parameters: Equatable { enum Replication { case `true` case `false` @@ -136,12 +136,33 @@ enum PostgresFrontendMessage: Equatable { /// of setting individual run-time parameters.) Spaces within this string are /// considered to separate arguments, unless escaped with a /// backslash (\); write \\ to represent a literal backslash. - var options: String? + var options: [(String, String)] /// Used to connect in streaming replication mode, where a small set of /// replication commands can be issued instead of SQL statements. Value /// can be true, false, or database, and the default is false. var replication: Replication + + static func ==(lhs: Self, rhs: Self) -> Bool { + guard lhs.user == rhs.user + && lhs.database == rhs.database + && lhs.replication == rhs.replication + && lhs.options.count == rhs.options.count + else { + return false + } + + var lhsIterator = lhs.options.makeIterator() + var rhsIterator = rhs.options.makeIterator() + + while let lhsNext = lhsIterator.next(), let rhsNext = rhsIterator.next() { + guard lhsNext.0 == rhsNext.0 && lhsNext.1 == rhsNext.1 else { + return false + } + } + return true + } + } var parameters: Parameters diff --git a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift index 39e9bb42..5af3bf34 100644 --- a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift @@ -11,7 +11,7 @@ class StartupTests: XCTestCase { let user = "test" let database = "abc123" - encoder.startup(user: user, database: database) + encoder.startup(user: user, database: database, options: []) byteBuffer = encoder.flushBuffer() let byteBufferLength = Int32(byteBuffer.readableBytes) @@ -32,7 +32,7 @@ class StartupTests: XCTestCase { let user = "test" - encoder.startup(user: user, database: nil) + encoder.startup(user: user, database: nil, options: []) byteBuffer = encoder.flushBuffer() let byteBufferLength = Int32(byteBuffer.readableBytes) @@ -44,4 +44,41 @@ class StartupTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, 0) } + + func testStartupMessageWithAdditionalOptions() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + var byteBuffer = ByteBuffer() + + let user = "test" + let database = "abc123" + + encoder.startup(user: user, database: database, options: [("some", "options")]) + byteBuffer = encoder.flushBuffer() + + let byteBufferLength = Int32(byteBuffer.readableBytes) + XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) + XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "some") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "options") + XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) + + XCTAssertEqual(byteBuffer.readableBytes, 0) + } +} + +extension PostgresFrontendMessage.Startup.Parameters.Replication { + var stringValue: String { + switch self { + case .true: + return "true" + case .false: + return "false" + case .database: + return "replication" + } + } } diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index b81d0899..dfdcc53e 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -37,9 +37,8 @@ class PostgresChannelHandlerTests: XCTestCase { XCTAssertEqual(startup.parameters.user, config.username) XCTAssertEqual(startup.parameters.database, config.database) - XCTAssertEqual(startup.parameters.options, nil) - XCTAssertEqual(startup.parameters.replication, .false) - + XCTAssert(startup.parameters.options.isEmpty) + XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.authentication(.ok))) XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678)))) XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.readyForQuery(.idle))) @@ -209,7 +208,7 @@ class PostgresChannelHandlerTests: XCTestCase { XCTAssertEqual(startup.parameters.user, config.username) XCTAssertEqual(startup.parameters.database, config.database) - XCTAssertEqual(startup.parameters.options, nil) + XCTAssert(startup.parameters.options.isEmpty) XCTAssertEqual(startup.parameters.replication, .false) var buffer = ByteBuffer() @@ -282,7 +281,7 @@ extension AuthContext { PostgresFrontendMessage.Startup.Parameters( user: self.username, database: self.database, - options: nil, + options: self.additionalParameters, replication: .false ) } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 3b1a8ca9..82baf914 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -602,7 +602,7 @@ class PostgresConnectionTests: XCTestCase { async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: self.logger) let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) - XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", replication: .false)))) + XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false)))) try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) From ea0800d12bbf70a3968b6ccd0cf17bb5d861530f Mon Sep 17 00:00:00 2001 From: Timo <38291523+lovetodream@users.noreply.github.com> Date: Tue, 9 Jan 2024 12:53:32 +0100 Subject: [PATCH 218/292] Fix Availability for DiscardingTaskGroup on watchOS (#448) --- Sources/ConnectionPoolModule/ConnectionPool.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index c20fa59e..9f25e82c 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -591,7 +591,7 @@ protocol TaskGroupProtocol { } #if swift(>=5.8) && os(Linux) || swift(>=5.9) -@available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 9.0, *) +@available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) extension DiscardingTaskGroup: TaskGroupProtocol {} #endif From 6ce96ab041ee055d6da97717fafa742b0f5915c9 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Tue, 30 Jan 2024 04:13:09 -0600 Subject: [PATCH 219/292] Add `Sendable` conformance to `PostgresEncodingContext` (#450) --- Sources/PostgresNIO/New/PostgresCodable.swift | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Sources/PostgresNIO/New/PostgresCodable.swift b/Sources/PostgresNIO/New/PostgresCodable.swift index 71c689bf..fd82c8ea 100644 --- a/Sources/PostgresNIO/New/PostgresCodable.swift +++ b/Sources/PostgresNIO/New/PostgresCodable.swift @@ -166,11 +166,10 @@ extension PostgresDynamicTypeEncodable { /// A context that is passed to Swift objects that are encoded into the Postgres wire format. Used /// to pass further information to the encoding method. -public struct PostgresEncodingContext { +public struct PostgresEncodingContext: Sendable { /// A ``PostgresJSONEncoder`` used to encode the object to json. public var jsonEncoder: JSONEncoder - /// Creates a ``PostgresEncodingContext`` with the given ``PostgresJSONEncoder``. In case you want /// to use the a ``PostgresEncodingContext`` with an unconfigured Foundation `JSONEncoder` /// you can use the ``default`` context instead. From e9b90b2189b6c64d41522d87616b04f6d978bb06 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Tue, 30 Jan 2024 08:28:08 -0600 Subject: [PATCH 220/292] Fix mishandling of SASL attribute parsing (#451) --- .../SASLAuthentication+SCRAM-SHA256.swift | 7 +++--- .../AuthenticationStateMachineTests.swift | 24 +++++++++++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift index f2fd8e1a..ac1d9ead 100644 --- a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift +++ b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift @@ -209,14 +209,13 @@ fileprivate struct SCRAMMessageParser { } static func parse(raw: [UInt8], isGS2Header: Bool = false) -> [SCRAMAttribute]? { - // There are two ways to implement this parse: // 1. All-at-once: Split on comma, split each on equals, validate // each results in a valid attribute. // 2. Sequential: State machine lookahead parse. // The former is simpler. The latter provides better validation. - let likelyAttributeSets = raw.split(separator: .comma, maxSplits: isGS2Header ? 3 : Int.max, omittingEmptySubsequences: false) - let likelyAttributePairs = likelyAttributeSets.map { $0.split(separator: .equals, maxSplits: 2, omittingEmptySubsequences: false) } + let likelyAttributeSets = raw.split(separator: .comma, maxSplits: isGS2Header ? 2 : Int.max, omittingEmptySubsequences: false) + let likelyAttributePairs = likelyAttributeSets.map { $0.split(separator: .equals, maxSplits: 1, omittingEmptySubsequences: false) } let results = likelyAttributePairs.map { parseAttributePair(name: Array($0[0]), value: $0.dropFirst().first.map { Array($0) } ?? [], isGS2Header: isGS2Header) } let validResults = results.compactMap { $0 } @@ -369,7 +368,7 @@ internal struct SHA256_PLUS: SASLAuthenticationMechanism { } // enum SCRAM } // enum SASLMechanism -/// Common impplementation of SCRAM-SHA-256 and SCRAM-SHA-256-PLUS +/// Common implementation of SCRAM-SHA-256 and SCRAM-SHA-256-PLUS fileprivate final class SASLMechanism_SCRAM_SHA256_Common { /// Initialized with initial client state diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift index b06b69ab..df881f90 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift @@ -45,6 +45,30 @@ class AuthenticationStateMachineTests: XCTestCase { XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait) } + func testAuthenticateSCRAMSHA256WithAtypicalEncoding() { + let authContext = AuthContext(username: "test", password: "abc123", database: "test") + var state = ConnectionStateMachine(requireBackendKeyData: true) + XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) + + let saslResponse = state.authenticationMessageReceived(.sasl(names: ["SCRAM-SHA-256"])) + guard case .sendSaslInitialResponse(name: let name, initialResponse: let responseData) = saslResponse else { + return XCTFail("\(saslResponse) is not .sendSaslInitialResponse") + } + let responseString = String(decoding: responseData, as: UTF8.self) + XCTAssertEqual(name, "SCRAM-SHA-256") + XCTAssert(responseString.starts(with: "n,,n=test,r=")) + + let saslContinueResponse = state.authenticationMessageReceived(.saslContinue(data: .init(bytes: + "r=\(responseString.dropFirst(12))RUJSZHhkeUVFNzRLNERKMkxmU05ITU1NZWcxaQ==,s=ijgUVaWgCDLRJyF963BKNA==,i=4096".utf8 + ))) + guard case .sendSaslResponse(let responseData2) = saslContinueResponse else { + return XCTFail("\(saslContinueResponse) is not .sendSaslResponse") + } + let response2String = String(decoding: responseData2, as: UTF8.self) + XCTAssertEqual(response2String.prefix(76), "c=biws,r=\(responseString.dropFirst(12))RUJSZHhkeUVFNzRLNERKMkxmU05ITU1NZWcxaQ==,p=") + } + func testAuthenticationFailure() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) From 69ccfdf4c80144d845e3b439961b7ec6cd7ae33f Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 31 Jan 2024 16:23:36 +0100 Subject: [PATCH 221/292] Be resilient about a read after connection closed (#452) fixes #449 --- .../ConnectionStateMachine.swift | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index d7a609a6..8c3252de 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -624,21 +624,19 @@ struct ConnectionStateMachine { mutating func readEventCaught() -> ConnectionAction { switch self.state { case .initialized: - preconditionFailure("Received a read event on a connection that was never opened.") - case .sslRequestSent: - return .read - case .sslNegotiated: - return .read - case .sslHandlerAdded: - return .read - case .waitingToStartAuthentication: - return .read - case .authenticating: - return .read - case .authenticated: - return .read - case .readyForQuery: + preconditionFailure("Invalid state: \(self.state). Read event before connection established?") + + case .sslRequestSent, + .sslNegotiated, + .sslHandlerAdded, + .waitingToStartAuthentication, + .authenticating, + .authenticated, + .readyForQuery, + .closing: + // all states in which we definitely want to make further forward progress... return .read + case .extendedQuery(var extendedQuery, let connectionContext): self.state = .modifying // avoid CoW let action = extendedQuery.readEventCaught() @@ -651,12 +649,15 @@ struct ConnectionStateMachine { self.state = .closeCommand(closeState, connectionContext) return self.modify(with: action) - case .closing: - return .read case .closed: - preconditionFailure("How can we receive a read, if the connection is closed") + // Generally we shouldn't see this event (read after connection closed?!). + // But truth is, adopters run into this, again and again. So preconditioning here leads + // to unnecessary crashes. So let's be resilient and just make more forward progress. + // If we really care, we probably need to dive deep into PostgresNIO and SwiftNIO. + return .read + case .modifying: - preconditionFailure("Invalid state") + preconditionFailure("Invalid state: \(self.state)") } } From 6433f6d87b0fa7daf9aaeb742bd3c8fd1f16ec26 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 21 Feb 2024 17:09:16 +0100 Subject: [PATCH 222/292] Fix warnings (#454) --- Sources/PostgresNIO/New/PSQLRowStream.swift | 1 + .../SASLAuthentication+SCRAM-SHA256.swift | 111 +++++++++--------- .../ConnectionPoolTests.swift | 2 +- .../ConnectionAction+TestUtils.swift | 12 +- .../New/Messages/DataRowTests.swift | 2 +- 5 files changed, 65 insertions(+), 63 deletions(-) diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index b3dfea30..0255e462 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -86,6 +86,7 @@ final class PSQLRowStream: @unchecked Sendable { elementType: DataRow.self, failureType: Error.self, backPressureStrategy: AdaptiveRowBuffer(), + finishOnDeinit: false, delegate: self ) diff --git a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift index ac1d9ead..2a717b6b 100644 --- a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift +++ b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift @@ -1,13 +1,10 @@ import Crypto import Foundation -extension UInt8: ExpressibleByUnicodeScalarLiteral { +extension UInt8 { fileprivate static var NUL: UInt8 { return 0x00 /* yeah, just U+0000 man */ } fileprivate static var comma: UInt8 { return 0x2c /* .init(ascii: ",") */ } fileprivate static var equals: UInt8 { return 0x3d /* .init(ascii: "=") */ } - public init(unicodeScalarLiteral value: Unicode.Scalar) { - self.init(ascii: value) - } } fileprivate extension String { @@ -87,7 +84,7 @@ fileprivate extension Array where Element == UInt8 { */ var isValidScramValue: Bool { // TODO: FInd a better way than doing a whole construction of String... - return self.count > 0 && !(String(bytes: self, encoding: .utf8)?.contains(",") ?? true) + return self.count > 0 && !(String(decoding: self, as: Unicode.UTF8.self).contains(",")) } } @@ -171,40 +168,40 @@ fileprivate struct SCRAMMessageParser { static func parseAttributePair(name: [UInt8], value: [UInt8], isGS2Header: Bool = false) -> SCRAMAttribute? { guard name.count == 1 || isGS2Header else { return nil } switch name.first { - case "m" where !isGS2Header: return .m(value) - case "r" where !isGS2Header: return String(printableAscii: value).map { .r($0) } - case "c" where !isGS2Header: - guard let parsedAttrs = value.decodingBase64().flatMap({ parse(raw: $0, isGS2Header: true) }) else { return nil } - guard (1...3).contains(parsedAttrs.count) else { return nil } - switch (parsedAttrs.first, parsedAttrs.dropFirst(1).first, parsedAttrs.dropFirst(2).first) { - case let (.gp(.bind(name, .none)), .a(ident), .gm(data)): return .c(binding: .bind(name, data), authIdentity: ident) - case let (.gp(.bind(name, .none)), .gm(data), .none): return .c(binding: .bind(name, data)) - case let (.gp(bind), .a(ident), .none): return .c(binding: bind, authIdentity: ident) - case let (.gp(bind), .none, .none): return .c(binding: bind) - default: return nil - } - case "n" where !isGS2Header: return String(bytes: value, encoding: .utf8)?.decodedAsSaslName.map { .n($0) } - case "s" where !isGS2Header: return value.decodingBase64().map { .s($0) } - case "i" where !isGS2Header: return String(printableAscii: value).flatMap { UInt32.init($0) }.map { .i($0) } - case "p" where !isGS2Header: return value.decodingBase64().map { .p($0) } - case "v" where !isGS2Header: return value.decodingBase64().map { .v($0) } - case "e" where !isGS2Header: // TODO: actually map the specific enum string values - guard value.isValidScramValue else { return nil } - return String(bytes: value, encoding: .utf8).flatMap { SCRAMServerError(rawValue: $0) }.map { .e($0) } - - case "y" where isGS2Header && value.count == 0: return .gp(.unused) - case "n" where isGS2Header && value.count == 0: return .gp(.unsupported) - case "p" where isGS2Header: return String(asciiAlphanumericMorse: value).map { .gp(.bind($0, nil)) } - case "a" where isGS2Header: return String(bytes: value, encoding: .utf8)?.decodedAsSaslName.map { .a($0) } - case .none where isGS2Header: return .a(nil) + case UInt8(ascii: "m") where !isGS2Header: return .m(value) + case UInt8(ascii: "r") where !isGS2Header: return String(printableAscii: value).map { .r($0) } + case UInt8(ascii: "c") where !isGS2Header: + guard let parsedAttrs = value.decodingBase64().flatMap({ parse(raw: $0, isGS2Header: true) }) else { return nil } + guard (1...3).contains(parsedAttrs.count) else { return nil } + switch (parsedAttrs.first, parsedAttrs.dropFirst(1).first, parsedAttrs.dropFirst(2).first) { + case let (.gp(.bind(name, .none)), .a(ident), .gm(data)): return .c(binding: .bind(name, data), authIdentity: ident) + case let (.gp(.bind(name, .none)), .gm(data), .none): return .c(binding: .bind(name, data)) + case let (.gp(bind), .a(ident), .none): return .c(binding: bind, authIdentity: ident) + case let (.gp(bind), .none, .none): return .c(binding: bind) + default: return nil + } + case UInt8(ascii: "n") where !isGS2Header: return String(decoding: value, as: Unicode.UTF8.self).decodedAsSaslName.map { .n($0) } + case UInt8(ascii: "s") where !isGS2Header: return value.decodingBase64().map { .s($0) } + case UInt8(ascii: "i") where !isGS2Header: return String(printableAscii: value).flatMap { UInt32.init($0) }.map { .i($0) } + case UInt8(ascii: "p") where !isGS2Header: return value.decodingBase64().map { .p($0) } + case UInt8(ascii: "v") where !isGS2Header: return value.decodingBase64().map { .v($0) } + case UInt8(ascii: "e") where !isGS2Header: // TODO: actually map the specific enum string values + guard value.isValidScramValue else { return nil } + return SCRAMServerError(rawValue: String(decoding: value, as: Unicode.UTF8.self)).flatMap { .e($0) } - default: - if isGS2Header { - return .gm(name + value) - } else { - guard value.count > 0, value.isValidScramValue else { return nil } - return .optional(name: CChar(name[0]), value: value) - } + case UInt8(ascii: "y") where isGS2Header && value.count == 0: return .gp(.unused) + case UInt8(ascii: "n") where isGS2Header && value.count == 0: return .gp(.unsupported) + case UInt8(ascii: "p") where isGS2Header: return String(asciiAlphanumericMorse: value).map { .gp(.bind($0, nil)) } + case UInt8(ascii: "a") where isGS2Header: return String(decoding: value, as: Unicode.UTF8.self).decodedAsSaslName.map { .a($0) } + case .none where isGS2Header: return .a(nil) + + default: + if isGS2Header { + return .gm(name + value) + } else { + guard value.count > 0, value.isValidScramValue else { return nil } + return .optional(name: CChar(name[0]), value: value) + } } } @@ -230,45 +227,45 @@ fileprivate struct SCRAMMessageParser { for attribute in attributes { switch attribute { case .m(let value): - result.append("m"); result.append("="); result.append(contentsOf: value) + result.append(UInt8(ascii: "m")); result.append(.equals); result.append(contentsOf: value) case .r(let nonce): - result.append("r"); result.append("="); result.append(contentsOf: nonce.utf8.map { UInt8($0) }) + result.append(UInt8(ascii: "r")); result.append(.equals); result.append(contentsOf: nonce.utf8.map { UInt8($0) }) case .n(let name): - result.append("n"); result.append("="); result.append(contentsOf: name.encodedAsSaslName.utf8.map { UInt8($0) }) + result.append(UInt8(ascii: "n")); result.append(.equals); result.append(contentsOf: name.encodedAsSaslName.utf8.map { UInt8($0) }) case .s(let salt): - result.append("s"); result.append("="); result.append(contentsOf: salt.encodingBase64()) + result.append(UInt8(ascii: "s")); result.append(.equals); result.append(contentsOf: salt.encodingBase64()) case .i(let count): - result.append("i"); result.append("="); result.append(contentsOf: "\(count)".utf8.map { UInt8($0) }) + result.append(UInt8(ascii: "i")); result.append(.equals); result.append(contentsOf: "\(count)".utf8.map { UInt8($0) }) case .p(let proof): - result.append("p"); result.append("="); result.append(contentsOf: proof.encodingBase64()) + result.append(UInt8(ascii: "p")); result.append(.equals); result.append(contentsOf: proof.encodingBase64()) case .v(let signature): - result.append("v"); result.append("="); result.append(contentsOf: signature.encodingBase64()) + result.append(UInt8(ascii: "v")); result.append(.equals); result.append(contentsOf: signature.encodingBase64()) case .e(let error): - result.append("e"); result.append("="); result.append(contentsOf: error.rawValue.utf8.map { UInt8($0) }) + result.append(UInt8(ascii: "e")); result.append(.equals); result.append(contentsOf: error.rawValue.utf8.map { UInt8($0) }) case .c(let binding, let identity): if isInitialGS2Header { switch binding { - case .unsupported: result.append("n") - case .unused: result.append("y") - case .bind(let name, _): result.append("p"); result.append("="); result.append(contentsOf: name.utf8.map { UInt8($0) }) + case .unsupported: result.append(UInt8(ascii: "n")) + case .unused: result.append(UInt8(ascii: "y")) + case .bind(let name, _): result.append(UInt8(ascii: "p")); result.append(.equals); result.append(contentsOf: name.utf8.map { UInt8($0) }) } - result.append(",") + result.append(.comma) if let identity = identity { - result.append("a"); result.append("="); result.append(contentsOf: identity.encodedAsSaslName.utf8.map { UInt8($0) }) + result.append(UInt8(ascii: "a")); result.append(.equals); result.append(contentsOf: identity.encodedAsSaslName.utf8.map { UInt8($0) }) } - result.append(",") + result.append(.comma) } else { guard var partial = serialize([attribute], isInitialGS2Header: true) else { return nil } if case let .bind(_, data) = binding { guard let data = data else { return nil } partial.append(contentsOf: data) } - result.append("c"); result.append("="); result.append(contentsOf: partial.encodingBase64()) + result.append(UInt8(ascii: "c")); result.append(.equals); result.append(contentsOf: partial.encodingBase64()) } default: return nil } - result.append(",") + result.append(.comma) } return result.dropLast() } @@ -472,7 +469,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { let saltedPassword = Hi(string: password, salt: serverSalt, iterations: serverIterations) let clientKey = HMAC.authenticationCode(for: "Client Key".data(using: .utf8)!, using: .init(data: saltedPassword)) let storedKey = SHA256.hash(data: Data(clientKey)) - var authMessage = firstMessageBare; authMessage.append(","); authMessage.append(contentsOf: message); authMessage.append(","); authMessage.append(contentsOf: clientFinalNoProof) + var authMessage = firstMessageBare; authMessage.append(.comma); authMessage.append(contentsOf: message); authMessage.append(.comma); authMessage.append(contentsOf: clientFinalNoProof) let clientSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: storedKey)) var clientProof = Array(clientKey) @@ -485,7 +482,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { } // Generate a `client-final-message` - var clientFinalMessage = clientFinalNoProof; clientFinalMessage.append(",") + var clientFinalMessage = clientFinalNoProof; clientFinalMessage.append(.comma) guard let proofPart = SCRAMMessageParser.serialize([.p(Array(clientProof))]) else { throw SASLAuthenticationError.genericAuthenticationFailure } clientFinalMessage.append(contentsOf: proofPart) @@ -590,7 +587,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { // Compute client signature let clientKey = HMAC.authenticationCode(for: "Client Key".data(using: .utf8)!, using: .init(data: saltedPassword)) let storedKey = SHA256.hash(data: Data(clientKey)) - var authMessage = clientBareFirstMessage; authMessage.append(","); authMessage.append(contentsOf: serverFirstMessage); authMessage.append(","); authMessage.append(contentsOf: message.dropLast(proof.count + 3)) + var authMessage = clientBareFirstMessage; authMessage.append(.comma); authMessage.append(contentsOf: serverFirstMessage); authMessage.append(.comma); authMessage.append(contentsOf: message.dropLast(proof.count + 3)) let clientSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: storedKey)) // Recompute client key from signature and proof, verify match diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index ba3c6a3f..3e3c9d65 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -803,7 +803,7 @@ final class ConnectionPoolTests: XCTestCase { pool.connectionReceivedNewMaxStreamSetting(connection, newMaxStreamSetting: 21) - for (index, request) in requests.enumerated() { + for (_, request) in requests.enumerated() { let connection = try await request.future.success connections.append(connection) } diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift index febeee37..d20032a8 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -2,7 +2,8 @@ import class Foundation.JSONEncoder import NIOCore @testable import PostgresNIO -extension ConnectionStateMachine.ConnectionAction: Equatable { +// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol +extension PostgresNIO.ConnectionStateMachine.ConnectionAction: Swift.Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { switch (lhs, rhs) { case (.read, read): @@ -47,7 +48,8 @@ extension ConnectionStateMachine.ConnectionAction: Equatable { } } -extension ConnectionStateMachine.ConnectionAction.CleanUpContext: Equatable { +// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol' +extension PostgresNIO.ConnectionStateMachine.ConnectionAction.CleanUpContext: Swift.Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { guard lhs.closePromise?.futureResult === rhs.closePromise?.futureResult else { return false @@ -96,13 +98,15 @@ extension ConnectionStateMachine { } } -extension PSQLError: Equatable { +// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol +extension PostgresNIO.PSQLError: Swift.Equatable { public static func == (lhs: PSQLError, rhs: PSQLError) -> Bool { return true } } -extension PSQLTask: Equatable { +// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol +extension PostgresNIO.PSQLTask: Swift.Equatable { public static func == (lhs: PSQLTask, rhs: PSQLTask) -> Bool { switch (lhs, rhs) { case (.extendedQuery(let lhs), .extendedQuery(let rhs)): diff --git a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift index db31b98a..a90d1e93 100644 --- a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift @@ -113,7 +113,7 @@ class DataRowTests: XCTestCase { } } -extension DataRow: ExpressibleByArrayLiteral { +extension PostgresNIO.DataRow: Swift.ExpressibleByArrayLiteral { public typealias ArrayLiteralElement = PostgresEncodable public init(arrayLiteral elements: PostgresEncodable...) { From 85d189c461b96a73f42df7b61c9d16dd06f74bfa Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 21 Feb 2024 17:24:55 +0100 Subject: [PATCH 223/292] Run queries directly on PostgresClient (#456) --- Sources/PostgresNIO/New/PSQLRowStream.swift | 27 +++++----- Sources/PostgresNIO/Pool/PostgresClient.swift | 52 +++++++++++++++++++ .../PostgresClientTests.swift | 37 +++++++++++++ 3 files changed, 104 insertions(+), 12 deletions(-) diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index 0255e462..b7f2d4fb 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -35,7 +35,7 @@ final class PSQLRowStream: @unchecked Sendable { case iteratingRows(onRow: (PostgresRow) throws -> (), EventLoopPromise, PSQLRowsDataSource) case waitingForAll([PostgresRow], EventLoopPromise<[PostgresRow]>, PSQLRowsDataSource) case consumed(Result) - case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource) + case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource, onFinish: @Sendable () -> ()) } internal let rowDescription: [RowDescription.Column] @@ -75,7 +75,7 @@ final class PSQLRowStream: @unchecked Sendable { // MARK: Async Sequence - func asyncSequence() -> PostgresRowSequence { + func asyncSequence(onFinish: @escaping @Sendable () -> () = {}) -> PostgresRowSequence { self.eventLoop.preconditionInEventLoop() guard case .waitingForConsumer(let bufferState) = self.downstreamState else { @@ -95,13 +95,13 @@ final class PSQLRowStream: @unchecked Sendable { switch bufferState { case .streaming(let bufferedRows, let dataSource): let yieldResult = source.yield(contentsOf: bufferedRows) - self.downstreamState = .asyncSequence(source, dataSource) - + self.downstreamState = .asyncSequence(source, dataSource, onFinish: onFinish) self.executeActionBasedOnYieldResult(yieldResult, source: dataSource) case .finished(let buffer, let commandTag): _ = source.yield(contentsOf: buffer) source.finish() + onFinish() self.downstreamState = .consumed(.success(commandTag)) case .failure(let error): @@ -130,7 +130,7 @@ final class PSQLRowStream: @unchecked Sendable { case .consumed: break - case .asyncSequence(_, let dataSource): + case .asyncSequence(_, let dataSource, _): dataSource.request(for: self) } } @@ -147,9 +147,10 @@ final class PSQLRowStream: @unchecked Sendable { private func cancel0() { switch self.downstreamState { - case .asyncSequence(_, let dataSource): + case .asyncSequence(_, let dataSource, let onFinish): self.downstreamState = .consumed(.failure(CancellationError())) dataSource.cancel(for: self) + onFinish() case .consumed: return @@ -320,7 +321,7 @@ final class PSQLRowStream: @unchecked Sendable { // immediately request more dataSource.request(for: self) - case .asyncSequence(let consumer, let source): + case .asyncSequence(let consumer, let source, _): let yieldResult = consumer.yield(contentsOf: newRows) self.executeActionBasedOnYieldResult(yieldResult, source: source) @@ -359,10 +360,11 @@ final class PSQLRowStream: @unchecked Sendable { self.downstreamState = .consumed(.success(commandTag)) promise.succeed(rows) - case .asyncSequence(let source, _): - source.finish() + case .asyncSequence(let source, _, let onFinish): self.downstreamState = .consumed(.success(commandTag)) - + source.finish() + onFinish() + case .consumed: break } @@ -384,9 +386,10 @@ final class PSQLRowStream: @unchecked Sendable { self.downstreamState = .consumed(.failure(error)) promise.fail(error) - case .asyncSequence(let consumer, _): - consumer.finish(error) + case .asyncSequence(let consumer, _, let onFinish): self.downstreamState = .consumed(.failure(error)) + consumer.finish(error) + onFinish() case .consumed: break diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index fc5a5b00..5b1bfa38 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -290,6 +290,58 @@ public final class PostgresClient: Sendable { return try await closure(connection) } + /// Run a query on the Postgres server the client is connected to. + /// + /// - Parameters: + /// - query: The ``PostgresQuery`` to run + /// - logger: The `Logger` to log into for the query + /// - file: The file, the query was started in. Used for better error reporting. + /// - line: The line, the query was started in. Used for better error reporting. + /// - Returns: A ``PostgresRowSequence`` containing the rows the server sent as the query result. + /// The sequence be discarded. + @discardableResult + public func query( + _ query: PostgresQuery, + logger: Logger, + file: String = #fileID, + line: Int = #line + ) async throws -> PostgresRowSequence { + do { + guard query.binds.count <= Int(UInt16.max) else { + throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line) + } + + let connection = try await self.leaseConnection() + + var logger = logger + logger[postgresMetadataKey: .connectionID] = "\(connection.id)" + + let promise = connection.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let context = ExtendedQueryContext( + query: query, + logger: logger, + promise: promise + ) + + connection.channel.write(HandlerTask.extendedQuery(context), promise: nil) + + promise.futureResult.whenFailure { _ in + self.pool.releaseConnection(connection) + } + + return try await promise.futureResult.map { + $0.asyncSequence(onFinish: { + self.pool.releaseConnection(connection) + }) + }.get() + } catch var error as PSQLError { + error.file = file + error.line = line + error.query = query + throw error // rethrow with more metadata + } + } + /// The client's run method. Users must call this function in order to start the client's background task processing /// like creating and destroying connections and running timers. /// diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index b1e7f9a8..4f22517e 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -41,6 +41,43 @@ final class PostgresClientTests: XCTestCase { taskGroup.cancelAll() } } + + func testQueryDirectly() async throws { + var mlogger = Logger(label: "test") + mlogger.logLevel = .debug + let logger = mlogger + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8) + self.addTeardownBlock { + try await eventLoopGroup.shutdownGracefully() + } + + let clientConfig = PostgresClient.Configuration.makeTestConfiguration() + let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger) + + await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() + } + + for i in 0..<10000 { + taskGroup.addTask { + do { + try await client.query("SELECT 1", logger: logger) + logger.info("Success", metadata: ["run": "\(i)"]) + } catch { + XCTFail("Unexpected error: \(error)") + } + } + } + + for _ in 0..<10000 { + _ = await taskGroup.nextResult()! + } + + taskGroup.cancelAll() + } + } + } @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) From 0679ede84f4c628f4d60810c32a33ced02e178ea Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 21 Feb 2024 17:50:15 +0100 Subject: [PATCH 224/292] Fix prepared statements (#455) --- .../Connection/PostgresConnection.swift | 9 ++- .../ConnectionStateMachine.swift | 8 +- .../ExtendedQueryStateMachine.swift | 14 ++-- Sources/PostgresNIO/New/PSQLTask.swift | 14 +++- .../New/PostgresChannelHandler.swift | 10 ++- .../PostgresNIO/New/PreparedStatement.swift | 23 +++++- Tests/IntegrationTests/AsyncTests.swift | 81 +++++++++++++++++++ .../PrepareStatementStateMachineTests.swift | 12 +-- .../PreparedStatementStateMachineTests.swift | 1 + .../ConnectionAction+TestUtils.swift | 4 +- .../New/PostgresConnectionTests.swift | 8 +- 11 files changed, 150 insertions(+), 34 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index f79a5555..eb9dc791 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -234,6 +234,7 @@ public final class PostgresConnection: @unchecked Sendable { let context = ExtendedQueryContext( name: name, query: query, + bindingDataTypes: [], logger: logger, promise: promise ) @@ -472,9 +473,10 @@ extension PostgresConnection { let bindings = try preparedStatement.makeBindings() let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( - name: String(reflecting: Statement.self), + name: Statement.name, sql: Statement.sql, bindings: bindings, + bindingDataTypes: Statement.bindingDataTypes, logger: logger, promise: promise )) @@ -493,10 +495,10 @@ extension PostgresConnection { ) throw error // rethrow with more metadata } - } /// Execute a prepared statement, taking care of the preparation when necessary + @_disfavoredOverload public func execute( _ preparedStatement: Statement, logger: Logger, @@ -506,9 +508,10 @@ extension PostgresConnection { let bindings = try preparedStatement.makeBindings() let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( - name: String(reflecting: Statement.self), + name: Statement.name, sql: Statement.sql, bindings: bindings, + bindingDataTypes: Statement.bindingDataTypes, logger: logger, promise: promise )) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 8c3252de..9d264bcc 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -97,7 +97,7 @@ struct ConnectionStateMachine { case forwardStreamError(PSQLError, read: Bool, cleanupContext: CleanUpContext?) // Prepare statement actions - case sendParseDescribeSync(name: String, query: String) + case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType]) case succeedPreparedStatementCreation(EventLoopPromise, with: RowDescription?) case failPreparedStatementCreation(EventLoopPromise, with: PSQLError, cleanupContext: CleanUpContext?) @@ -587,7 +587,7 @@ struct ConnectionStateMachine { switch queryContext.query { case .executeStatement(_, let promise), .unnamed(_, let promise): return .failQuery(promise, with: psqlErrror, cleanupContext: nil) - case .prepareStatement(_, _, let promise): + case .prepareStatement(_, _, _, let promise): return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil) } case .closeCommand(let closeContext): @@ -1057,8 +1057,8 @@ extension ConnectionStateMachine { return .read case .wait: return .wait - case .sendParseDescribeSync(name: let name, query: let query): - return .sendParseDescribeSync(name: name, query: query) + case .sendParseDescribeSync(name: let name, query: let query, bindingDataTypes: let bindingDataTypes): + return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes) case .succeedPreparedStatementCreation(let promise, with: let rowDescription): return .succeedPreparedStatementCreation(promise, with: rowDescription) case .failPreparedStatementCreation(let promise, with: let error): diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 3a84031b..78f0d202 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -26,7 +26,7 @@ struct ExtendedQueryStateMachine { enum Action { case sendParseDescribeBindExecuteSync(PostgresQuery) - case sendParseDescribeSync(name: String, query: String) + case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType]) case sendBindExecuteSync(PSQLExecuteStatement) // --- general actions @@ -79,10 +79,10 @@ struct ExtendedQueryStateMachine { return .sendBindExecuteSync(prepared) } - case .prepareStatement(let name, let query, _): + case .prepareStatement(let name, let query, let bindingDataTypes, _): return self.avoidingStateMachineCoW { state -> Action in state = .messagesSent(queryContext) - return .sendParseDescribeSync(name: name, query: query) + return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes) } } } @@ -107,7 +107,7 @@ struct ExtendedQueryStateMachine { case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return .failQuery(eventLoopPromise, with: .queryCancelled) - case .prepareStatement(_, _, let eventLoopPromise): + case .prepareStatement(_, _, _, let eventLoopPromise): return .failPreparedStatementCreation(eventLoopPromise, with: .queryCancelled) } @@ -165,7 +165,7 @@ struct ExtendedQueryStateMachine { return .wait } - case .prepareStatement(_, _, let promise): + case .prepareStatement(_, _, _, let promise): return self.avoidingStateMachineCoW { state -> Action in state = .noDataMessageReceived(queryContext) return .succeedPreparedStatementCreation(promise, with: nil) @@ -200,7 +200,7 @@ struct ExtendedQueryStateMachine { case .unnamed, .executeStatement: return .wait - case .prepareStatement(_, _, let eventLoopPromise): + case .prepareStatement(_, _, _, let eventLoopPromise): return .succeedPreparedStatementCreation(eventLoopPromise, with: rowDescription) } } @@ -477,7 +477,7 @@ struct ExtendedQueryStateMachine { switch context.query { case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return .failQuery(eventLoopPromise, with: error) - case .prepareStatement(_, _, let eventLoopPromise): + case .prepareStatement(_, _, _, let eventLoopPromise): return .failPreparedStatementCreation(eventLoopPromise, with: error) } } diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 6308a5b3..363f9394 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -21,7 +21,7 @@ enum PSQLTask { eventLoopPromise.fail(error) case .executeStatement(_, let eventLoopPromise): eventLoopPromise.fail(error) - case .prepareStatement(_, _, let eventLoopPromise): + case .prepareStatement(_, _, _, let eventLoopPromise): eventLoopPromise.fail(error) } @@ -35,7 +35,7 @@ final class ExtendedQueryContext { enum Query { case unnamed(PostgresQuery, EventLoopPromise) case executeStatement(PSQLExecuteStatement, EventLoopPromise) - case prepareStatement(name: String, query: String, EventLoopPromise) + case prepareStatement(name: String, query: String, bindingDataTypes: [PostgresDataType], EventLoopPromise) } let query: Query @@ -62,10 +62,11 @@ final class ExtendedQueryContext { init( name: String, query: String, + bindingDataTypes: [PostgresDataType], logger: Logger, promise: EventLoopPromise ) { - self.query = .prepareStatement(name: name, query: query, promise) + self.query = .prepareStatement(name: name, query: query, bindingDataTypes: bindingDataTypes, promise) self.logger = logger } } @@ -73,6 +74,7 @@ final class ExtendedQueryContext { final class PreparedStatementContext: Sendable { let name: String let sql: String + let bindingDataTypes: [PostgresDataType] let bindings: PostgresBindings let logger: Logger let promise: EventLoopPromise @@ -81,12 +83,18 @@ final class PreparedStatementContext: Sendable { name: String, sql: String, bindings: PostgresBindings, + bindingDataTypes: [PostgresDataType], logger: Logger, promise: EventLoopPromise ) { self.name = name self.sql = sql self.bindings = bindings + if bindingDataTypes.isEmpty { + self.bindingDataTypes = bindings.metadata.map(\.dataType) + } else { + self.bindingDataTypes = bindingDataTypes + } self.logger = logger self.promise = promise } diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 54ae0fc9..32dea4a5 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -345,8 +345,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.closeConnectionAndCleanup(cleanupContext, context: context) case .fireChannelInactive: context.fireChannelInactive() - case .sendParseDescribeSync(let name, let query): - self.sendParseDecribeAndSyncMessage(statementName: name, query: query, context: context) + case .sendParseDescribeSync(let name, let query, let bindingDataTypes): + self.sendParseDescribeAndSyncMessage(statementName: name, query: query, bindingDataTypes: bindingDataTypes, context: context) case .sendBindExecuteSync(let executeStatement): self.sendBindExecuteAndSyncMessage(executeStatement: executeStatement, context: context) case .sendParseDescribeBindExecuteSync(let query): @@ -489,13 +489,14 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } } - private func sendParseDecribeAndSyncMessage( + private func sendParseDescribeAndSyncMessage( statementName: String, query: String, + bindingDataTypes: [PostgresDataType], context: ChannelHandlerContext ) { precondition(self.rowStream == nil, "Expected to not have an open stream at this point") - self.encoder.parse(preparedStatementName: statementName, query: query, parameters: []) + self.encoder.parse(preparedStatementName: statementName, query: query, parameters: bindingDataTypes) self.encoder.describePreparedStatement(statementName) self.encoder.sync() context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) @@ -724,6 +725,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { return .extendedQuery(.init( name: preparedStatement.name, query: preparedStatement.sql, + bindingDataTypes: preparedStatement.bindingDataTypes, logger: preparedStatement.logger, promise: promise )) diff --git a/Sources/PostgresNIO/New/PreparedStatement.swift b/Sources/PostgresNIO/New/PreparedStatement.swift index 1e0b5d5a..21165388 100644 --- a/Sources/PostgresNIO/New/PreparedStatement.swift +++ b/Sources/PostgresNIO/New/PreparedStatement.swift @@ -26,15 +26,36 @@ /// Structs conforming to this protocol can then be used with `PostgresConnection.execute(_ preparedStatement:, logger:)`, /// which will take care of preparing the statement on the server side and executing it. public protocol PostgresPreparedStatement: Sendable { + /// The prepared statements name. + /// + /// > Note: There is a default implementation that returns the implementor's name. + static var name: String { get } + /// The type rows returned by the statement will be decoded into associatedtype Row /// The SQL statement to prepare on the database server. static var sql: String { get } - /// Make the bindings to provided concrete values to use when executing the prepared SQL statement + /// The postgres data types of the values that are bind when this statement is executed. + /// + /// If an empty array is returned the datatypes are inferred from the ``PostgresBindings`` returned + /// from ``PostgresPreparedStatement/makeBindings()``. + /// + /// > Note: There is a default implementation that returns an empty array, which will lead to + /// automatic inference. + static var bindingDataTypes: [PostgresDataType] { get } + + /// Make the bindings to provided concrete values to use when executing the prepared SQL statement. + /// The order must match ``PostgresPreparedStatement/bindingDataTypes-4b6tx``. func makeBindings() throws -> PostgresBindings /// Decode a row returned by the database into an instance of `Row` func decodeRow(_ row: PostgresRow) throws -> Row } + +extension PostgresPreparedStatement { + public static var name: String { String(reflecting: self) } + + public static var bindingDataTypes: [PostgresDataType] { [] } +} diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 91b5656c..75e5b6ba 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -358,6 +358,87 @@ final class AsyncPostgresConnectionTests: XCTestCase { } } } + + static let preparedStatementTestTable = "AsyncTestPreparedStatementTestTable" + func testPreparedStatementWithIntegerBinding() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + struct InsertPreparedStatement: PostgresPreparedStatement { + static let name = "INSERT-AsyncTestPreparedStatementTestTable" + + static let sql = #"INSERT INTO "\#(AsyncPostgresConnectionTests.preparedStatementTestTable)" (uuid) VALUES ($1);"# + typealias Row = () + + var uuid: UUID + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.uuid) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + () + } + } + + struct SelectPreparedStatement: PostgresPreparedStatement { + static let name = "SELECT-AsyncTestPreparedStatementTestTable" + + static let sql = #"SELECT id, uuid FROM "\#(AsyncPostgresConnectionTests.preparedStatementTestTable)" WHERE id <= $1;"# + typealias Row = (Int, UUID) + + var id: Int + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.id) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode((Int, UUID).self) + } + } + + do { + try await withTestConnection(on: eventLoop) { connection in + try await connection.query(""" + CREATE TABLE IF NOT EXISTS "\(unescaped: Self.preparedStatementTestTable)" ( + id SERIAL PRIMARY KEY, + uuid UUID NOT NULL + ) + """, + logger: .psqlTest + ) + + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + + let rows = try await connection.execute(SelectPreparedStatement(id: 3), logger: .psqlTest) + var counter = 0 + for try await (id, uuid) in rows { + Logger.psqlTest.info("Received row", metadata: [ + "id": "\(id)", "uuid": "\(uuid)" + ]) + counter += 1 + } + + try await connection.query(""" + DROP TABLE "\(unescaped: Self.preparedStatementTestTable)"; + """, + logger: .psqlTest + ) + } + } catch { + XCTFail("Unexpected error: \(String(describing: error))") + } + } } extension XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift index 6a08afeb..547f5cdf 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift @@ -12,11 +12,11 @@ class PrepareStatementStateMachineTests: XCTestCase { let name = "haha" let query = #"SELECT id FROM users WHERE id = $1 "# let prepareStatementContext = ExtendedQueryContext( - name: name, query: query, logger: .psqlTest, promise: promise + name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise ) XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), - .sendParseDescribeSync(name: name, query: query)) + .sendParseDescribeSync(name: name, query: query, bindingDataTypes: [])) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) @@ -38,11 +38,11 @@ class PrepareStatementStateMachineTests: XCTestCase { let name = "haha" let query = #"DELETE FROM users WHERE id = $1 "# let prepareStatementContext = ExtendedQueryContext( - name: name, query: query, logger: .psqlTest, promise: promise + name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise ) XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), - .sendParseDescribeSync(name: name, query: query)) + .sendParseDescribeSync(name: name, query: query, bindingDataTypes: [])) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) @@ -60,11 +60,11 @@ class PrepareStatementStateMachineTests: XCTestCase { let name = "haha" let query = #"DELETE FROM users WHERE id = $1 "# let prepareStatementContext = ExtendedQueryContext( - name: name, query: query, logger: .psqlTest, promise: promise + name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise ) XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), - .sendParseDescribeSync(name: name, query: query)) + .sendParseDescribeSync(name: name, query: query, bindingDataTypes: [])) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift index ab77a57c..f6c1ddf7 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift @@ -152,6 +152,7 @@ class PreparedStatementStateMachineTests: XCTestCase { name: "test", sql: "INSERT INTO test_table (column1) VALUES (1)", bindings: PostgresBindings(), + bindingDataTypes: [], logger: .psqlTest, promise: promise ) diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift index d20032a8..9a1224d8 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -36,8 +36,8 @@ extension PostgresNIO.ConnectionStateMachine.ConnectionAction: Swift.Equatable { return lhsBuffer == rhsBuffer && lhsCommandTag == rhsCommandTag case (.forwardStreamError(let lhsError, let lhsRead, let lhsCleanupContext), .forwardStreamError(let rhsError , let rhsRead, let rhsCleanupContext)): return lhsError == rhsError && lhsRead == rhsRead && lhsCleanupContext == rhsCleanupContext - case (.sendParseDescribeSync(let lhsName, let lhsQuery), .sendParseDescribeSync(let rhsName, let rhsQuery)): - return lhsName == rhsName && lhsQuery == rhsQuery + case (.sendParseDescribeSync(let lhsName, let lhsQuery, let lhsDataTypes), .sendParseDescribeSync(let rhsName, let rhsQuery, let rhsDataTypes)): + return lhsName == rhsName && lhsQuery == rhsQuery && lhsDataTypes == rhsDataTypes case (.succeedPreparedStatementCreation(let lhsPromise, let lhsRowDescription), .succeedPreparedStatementCreation(let rhsPromise, let rhsRowDescription)): return lhsPromise.futureResult === rhsPromise.futureResult && lhsRowDescription == rhsRowDescription case (.fireChannelInactive, .fireChannelInactive): diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 82baf914..a773cf2c 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -337,7 +337,7 @@ class PostgresConnectionTests: XCTestCase { let prepareRequest = try await channel.waitForPrepareRequest() XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + XCTAssertEqual(prepareRequest.parse.parameters.first, .text) guard case .preparedStatement(let name) = prepareRequest.describe else { fatalError("Describe should contain a prepared statement") } @@ -393,7 +393,7 @@ class PostgresConnectionTests: XCTestCase { let prepareRequest = try await channel.waitForPrepareRequest() XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + XCTAssertEqual(prepareRequest.parse.parameters.first, .text) guard case .preparedStatement(let name) = prepareRequest.describe else { fatalError("Describe should contain a prepared statement") } @@ -487,7 +487,7 @@ class PostgresConnectionTests: XCTestCase { // The channel deduplicates prepare requests, we're going to see only one of them let prepareRequest = try await channel.waitForPrepareRequest() XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + XCTAssertEqual(prepareRequest.parse.parameters.first, .text) guard case .preparedStatement(let name) = prepareRequest.describe else { fatalError("Describe should contain a prepared statement") } @@ -555,7 +555,7 @@ class PostgresConnectionTests: XCTestCase { let prepareRequest = try await channel.waitForPrepareRequest() XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + XCTAssertEqual(prepareRequest.parse.parameters.first, .text) guard case .preparedStatement(let name) = prepareRequest.describe else { fatalError("Describe should contain a prepared statement") } From 17b23b1a24f0e7b451be6ae27d30f29a4c29099f Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 21 Feb 2024 19:26:49 +0100 Subject: [PATCH 225/292] Adds prepared statement support to client (#459) --- Sources/PostgresNIO/Pool/PostgresClient.swift | 42 ++++++++++ .../PostgresClientTests.swift | 81 ++++++++++++++++++- .../New/PostgresConnectionTests.swift | 2 +- 3 files changed, 121 insertions(+), 4 deletions(-) diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 5b1bfa38..4a576085 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -342,6 +342,48 @@ public final class PostgresClient: Sendable { } } + /// Execute a prepared statement, taking care of the preparation when necessary + public func execute( + _ preparedStatement: Statement, + logger: Logger, + file: String = #fileID, + line: Int = #line + ) async throws -> AsyncThrowingMapSequence where Row == Statement.Row { + let bindings = try preparedStatement.makeBindings() + + do { + let connection = try await self.leaseConnection() + + let promise = connection.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let task = HandlerTask.executePreparedStatement(.init( + name: String(reflecting: Statement.self), + sql: Statement.sql, + bindings: bindings, + bindingDataTypes: Statement.bindingDataTypes, + logger: logger, + promise: promise + )) + connection.channel.write(task, promise: nil) + + promise.futureResult.whenFailure { _ in + self.pool.releaseConnection(connection) + } + + return try await promise.futureResult + .map { $0.asyncSequence(onFinish: { self.pool.releaseConnection(connection) }) } + .get() + .map { try preparedStatement.decodeRow($0) } + } catch var error as PSQLError { + error.file = file + error.line = line + error.query = .init( + unsafeSQL: Statement.sql, + binds: bindings + ) + throw error // rethrow with more metadata + } + } + /// The client's run method. Users must call this function in order to start the client's background task processing /// like creating and destroying connections and running timers. /// diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index 4f22517e..9115dc82 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -25,16 +25,17 @@ final class PostgresClientTests: XCTestCase { await client.run() } - for i in 0..<10000 { + let iterations = 1000 + + for i in 0.. PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.id) + return bindings + } + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode(Row.self) + } + } + + for try await (id, uuid) in try await client.execute(Example(id: 200), logger: logger) { + logger.info("id: \(id), uuid: \(uuid.uuidString)") + } + + try await client.query( + """ + DROP TABLE "\(unescaped: tableName)"; + """, + logger: logger + ) + + taskGroup.cancelAll() + } + } catch { + XCTFail("Unexpected error: \(String(reflecting: error))") + } + } } @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index a773cf2c..f2cd96f8 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -155,7 +155,7 @@ class PostgresConnectionTests: XCTestCase { _ = try await iterator.next() XCTFail("Did not expect to not throw") } catch { - print(error) + self.logger.error("error", metadata: ["error": "\(error)"]) } } From c75349fadbffaba06dedaf6c0eb936a4edff5dc5 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 22 Feb 2024 14:00:01 +0100 Subject: [PATCH 226/292] PostgresClient implements ServiceLifecycle's Service (#457) --- Package.swift | 2 ++ Sources/PostgresNIO/Pool/PostgresClient.swift | 27 ++++++++++--------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/Package.swift b/Package.swift index 814335bd..4d008371 100644 --- a/Package.swift +++ b/Package.swift @@ -22,6 +22,7 @@ let package = Package( .package(url: "https://github.com/apple/swift-crypto.git", "2.0.0" ..< "4.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.4.1"), .package(url: "https://github.com/apple/swift-log.git", from: "1.5.3"), + .package(url: "https://github.com/swift-server/swift-service-lifecycle.git", from: "2.4.1"), ], targets: [ .target( @@ -39,6 +40,7 @@ let package = Package( .product(name: "NIOTLS", package: "swift-nio"), .product(name: "NIOSSL", package: "swift-nio-ssl"), .product(name: "NIOFoundationCompat", package: "swift-nio"), + .product(name: "ServiceLifecycle", package: "swift-service-lifecycle"), ] ), .target( diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 4a576085..2c21cce7 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -2,6 +2,7 @@ import NIOCore import NIOSSL import Atomics import Logging +import ServiceLifecycle import _ConnectionPoolModule /// A Postgres client that is backed by an underlying connection pool. Use ``Configuration`` to change the client's @@ -17,23 +18,22 @@ import _ConnectionPoolModule /// client.run() // !important /// } /// -/// taskGroup.addTask { -/// client.withConnection { connection in -/// do { -/// let rows = try await connection.query("SELECT userID, name, age FROM users;") -/// for try await (userID, name, age) in rows.decode((UUID, String, Int).self) { -/// // do something with the values -/// } -/// } catch { -/// // handle errors -/// } +/// do { +/// let rows = try await connection.query("SELECT userID, name, age FROM users;") +/// for try await (userID, name, age) in rows.decode((UUID, String, Int).self) { +/// // do something with the values /// } +/// } catch { +/// // handle errors /// } +/// +/// // shutdown the client +/// taskGroup.cancelAll() /// } /// ``` @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) @_spi(ConnectionPool) -public final class PostgresClient: Sendable { +public final class PostgresClient: Sendable, ServiceLifecycle.Service { public struct Configuration: Sendable { public struct TLS: Sendable { enum Base { @@ -391,7 +391,10 @@ public final class PostgresClient: Sendable { public func run() async { let atomicOp = self.runningAtomic.compareExchange(expected: false, desired: true, ordering: .relaxed) precondition(!atomicOp.original, "PostgresClient.run() should just be called once!") - await self.pool.run() + + await cancelOnGracefulShutdown { + await self.pool.run() + } } // MARK: - Private Methods - From 7632411e5964f0fb8ffa92acd5cd7b6be46625a6 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 22 Feb 2024 15:43:51 +0100 Subject: [PATCH 227/292] Make PostgresClient API (#460) --- Sources/PostgresNIO/Pool/PostgresClient.swift | 30 ++++++++++++++----- .../PostgresClientTests.swift | 2 +- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 2c21cce7..865dafc8 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -8,11 +8,11 @@ import _ConnectionPoolModule /// A Postgres client that is backed by an underlying connection pool. Use ``Configuration`` to change the client's /// behavior. /// -/// > Important: +/// > Warning: /// The client can only lease connections if the user is running the client's ``run()`` method in a long running task: /// /// ```swift -/// let client = PostgresClient(configuration: configuration, logger: logger) +/// let client = PostgresClient(configuration: configuration) /// await withTaskGroup(of: Void.self) { /// taskGroup.addTask { /// client.run() // !important @@ -32,7 +32,6 @@ import _ConnectionPoolModule /// } /// ``` @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -@_spi(ConnectionPool) public final class PostgresClient: Sendable, ServiceLifecycle.Service { public struct Configuration: Sendable { public struct TLS: Sendable { @@ -246,8 +245,22 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { let factory: ConnectionFactory let runningAtomic = ManagedAtomic(false) let backgroundLogger: Logger - + + /// Creates a new ``PostgresClient``, that does not log any background information. + /// Don't forget to run ``run()`` the client in a long running task. + /// + /// - Parameters: + /// - configuration: The client's configuration. See ``Configuration`` for details. + /// - eventLoopGroup: The underlying NIO `EventLoopGroup`. Defaults to ``defaultEventLoopGroup``. + public convenience init( + configuration: Configuration, + eventLoopGroup: any EventLoopGroup = PostgresClient.defaultEventLoopGroup + ) { + self.init(configuration: configuration, eventLoopGroup: eventLoopGroup, backgroundLogger: Self.loggingDisabled) + } + /// Creates a new ``PostgresClient``. Don't forget to run ``run()`` the client in a long running task. + /// /// - Parameters: /// - configuration: The client's configuration. See ``Configuration`` for details. /// - eventLoopGroup: The underlying NIO `EventLoopGroup`. Defaults to ``defaultEventLoopGroup``. @@ -302,10 +315,11 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { @discardableResult public func query( _ query: PostgresQuery, - logger: Logger, + logger: Logger? = nil, file: String = #fileID, line: Int = #line ) async throws -> PostgresRowSequence { + let logger = logger ?? Self.loggingDisabled do { guard query.binds.count <= Int(UInt16.max) else { throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line) @@ -345,11 +359,12 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { /// Execute a prepared statement, taking care of the preparation when necessary public func execute( _ preparedStatement: Statement, - logger: Logger, + logger: Logger? = nil, file: String = #fileID, line: Int = #line ) async throws -> AsyncThrowingMapSequence where Row == Statement.Row { let bindings = try preparedStatement.makeBindings() + let logger = logger ?? Self.loggingDisabled do { let connection = try await self.leaseConnection() @@ -412,6 +427,8 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { public static var defaultEventLoopGroup: EventLoopGroup { PostgresConnection.defaultEventLoopGroup } + + static let loggingDisabled = Logger(label: "Postgres-do-not-log", factory: { _ in SwiftLogNoOpLogHandler() }) } @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) @@ -444,7 +461,6 @@ extension ConnectionPoolConfiguration { } } -@_spi(ConnectionPool) extension PostgresConnection: PooledConnection { public func close() { self.channel.close(mode: .all, promise: nil) diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index 9115dc82..d6d89dc3 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -27,7 +27,7 @@ final class PostgresClientTests: XCTestCase { let iterations = 1000 - for i in 0.. Date: Thu, 22 Feb 2024 16:14:04 +0100 Subject: [PATCH 228/292] Improve docs before releasing PostgresClient (#461) --- README.md | 66 ++++++------- Snippets/Birthdays.swift | 74 +++++++++++++++ Snippets/PostgresClient.swift | 40 ++++++++ Sources/PostgresNIO/Docs.docc/coding.md | 39 ++++++++ Sources/PostgresNIO/Docs.docc/deprecated.md | 43 +++++++++ Sources/PostgresNIO/Docs.docc/index.md | 93 +++++++------------ Sources/PostgresNIO/Docs.docc/listen.md | 9 ++ Sources/PostgresNIO/Docs.docc/migrations.md | 12 --- .../Docs.docc/prepared-statement.md | 7 ++ .../PostgresNIO/Docs.docc/running-queries.md | 27 ++++++ Sources/PostgresNIO/Pool/PostgresClient.swift | 62 ++++++++----- 11 files changed, 336 insertions(+), 136 deletions(-) create mode 100644 Snippets/Birthdays.swift create mode 100644 Snippets/PostgresClient.swift create mode 100644 Sources/PostgresNIO/Docs.docc/coding.md create mode 100644 Sources/PostgresNIO/Docs.docc/deprecated.md create mode 100644 Sources/PostgresNIO/Docs.docc/listen.md create mode 100644 Sources/PostgresNIO/Docs.docc/prepared-statement.md create mode 100644 Sources/PostgresNIO/Docs.docc/running-queries.md diff --git a/README.md b/README.md index ef1dc4ec..c2dc545e 100644 --- a/README.md +++ b/README.md @@ -28,15 +28,14 @@ Features: - A [`PostgresConnection`] which allows you to connect to, authorize with, query, and retrieve results from a PostgreSQL server +- A [`PostgresClient`] which pools and manages connections - An async/await interface that supports backpressure - Automatic conversions between Swift primitive types and the Postgres wire format -- Integrated with the Swift server ecosystem, including use of [SwiftLog]. +- Integrated with the Swift server ecosystem, including use of [SwiftLog] and [ServiceLifecycle]. - Designed to run efficiently on all supported platforms (tested extensively on Linux and Darwin systems) - Support for `Network.framework` when available (e.g. on Apple platforms) - Supports running on Unix Domain Sockets -PostgresNIO does not provide a `ConnectionPool` as of today, but this is a [feature high on our list](https://github.com/vapor/postgres-nio/issues/256). If you need a `ConnectionPool` today, please have a look at Vapor's [PostgresKit]. - ## API Docs Check out the [PostgresNIO API docs][Documentation] for a @@ -44,13 +43,16 @@ detailed look at all of the classes, structs, protocols, and more. ## Getting started +Interested in an example? We prepared a simple [Birthday example](/vapor/postgres-nio/tree/main/Snippets/Birthdays.swift) +in the Snippets folder. + #### Adding the dependency Add `PostgresNIO` as dependency to your `Package.swift`: ```swift dependencies: [ - .package(url: "https://github.com/vapor/postgres-nio.git", from: "1.14.0"), + .package(url: "https://github.com/vapor/postgres-nio.git", from: "1.21.0"), ... ] ``` @@ -64,14 +66,14 @@ Add `PostgresNIO` to the target you want to use it in: ] ``` -#### Creating a connection +#### Creating a client -To create a connection, first create a connection configuration object: +To create a [`PostgresClient`], which pools connections for you, first create a configuration object: ```swift import PostgresNIO -let config = PostgresConnection.Configuration( +let config = PostgresClient.Configuration( host: "localhost", port: 5432, username: "my_username", @@ -81,50 +83,35 @@ let config = PostgresConnection.Configuration( ) ``` -To create a connection we need a [`Logger`], that is used to log connection background events. - +Next you can create you client with it: ```swift -import Logging - -let logger = Logger(label: "postgres-logger") +let client = PostgresClient(configuration: config) ``` -Now we can put it together: - +Once you have create your client, you must [`run()`] it: ```swift -import PostgresNIO -import Logging - -let logger = Logger(label: "postgres-logger") - -let config = PostgresConnection.Configuration( - host: "localhost", - port: 5432, - username: "my_username", - password: "my_password", - database: "my_database", - tls: .disable -) +await withTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() // !important + } -let connection = try await PostgresConnection.connect( - configuration: config, - id: 1, - logger: logger -) + // You can use the client while the `client.run()` method is not cancelled. -// Close your connection once done -try await connection.close() + // To shutdown the client, cancel its run method, by cancelling the taskGroup. + taskGroup.cancelAll() +} ``` #### Querying -Once a connection is established, queries can be sent to the server. This is very straightforward: +Once a client is running, queries can be sent to the server. This is straightforward: ```swift -let rows = try await connection.query("SELECT id, username, birthday FROM users", logger: logger) +let rows = try await client.query("SELECT id, username, birthday FROM users") ``` -The query will return a [`PostgresRowSequence`], which is an AsyncSequence of [`PostgresRow`]s. The rows can be iterated one-by-one: +The query will return a [`PostgresRowSequence`], which is an AsyncSequence of [`PostgresRow`]s. +The rows can be iterated one-by-one: ```swift for try await row in rows { @@ -160,7 +147,7 @@ Sending parameterized queries to the database is also supported (in the coolest let id = 1 let username = "fancyuser" let birthday = Date() -try await connection.query(""" +try await client.query(""" INSERT INTO users (id, username, birthday) VALUES (\(id), \(username), \(birthday)) """, logger: logger @@ -184,6 +171,8 @@ Please see [SECURITY.md] for details on the security process. [Security.md]: https://github.com/vapor/.github/blob/main/SECURITY.md [`PostgresConnection`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresconnection +[`PostgresClient`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresclient +[`run()`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresclient/run() [`query(_:logger:)`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresconnection/query(_:logger:file:line:)-9mkfn [`PostgresQuery`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresquery [`PostgresRow`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresrow @@ -193,4 +182,5 @@ Please see [SECURITY.md] for details on the security process. [SwiftNIO]: https://github.com/apple/swift-nio [PostgresKit]: https://github.com/vapor/postgres-kit [SwiftLog]: https://github.com/apple/swift-log +[ServiceLifecycle]: https://github.com/swift-server/swift-service-lifecycle [`Logger`]: https://apple.github.io/swift-log/docs/current/Logging/Structs/Logger.html diff --git a/Snippets/Birthdays.swift b/Snippets/Birthdays.swift new file mode 100644 index 00000000..60516aa1 --- /dev/null +++ b/Snippets/Birthdays.swift @@ -0,0 +1,74 @@ +import PostgresNIO +import Foundation + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +enum Birthday { + static func main() async throws { + // 1. Create a configuration to match server's parameters + let config = PostgresClient.Configuration( + host: "localhost", + port: 5432, + username: "test_username", + password: "test_password", + database: "test_database", + tls: .disable + ) + + // 2. Create a client + let client = PostgresClient(configuration: config) + + // 3. Run the client + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() // !important + } + + // 4. Create a friends table to store data into + try await client.query(""" + CREATE TABLE IF NOT EXISTS "friends" ( + id SERIAL PRIMARY KEY, + given_name TEXT, + last_name TEXT, + birthday TIMESTAMP WITH TIME ZONE + ) + """ + ) + + // 5. Create a Swift friend representation + struct Friend { + var firstName: String + var lastName: String + var birthday: Date + } + + // 6. Create John Appleseed with special birthday + let dateFormatter = DateFormatter() + dateFormatter.dateFormat = "yyyy-MM-dd" + let johnsBirthday = dateFormatter.date(from: "1960-09-26")! + let friend = Friend(firstName: "Hans", lastName: "Müller", birthday: johnsBirthday) + + // 7. Store friend into the database + try await client.query(""" + INSERT INTO "friends" (given_name, last_name, birthday) + VALUES + (\(friend.firstName), \(friend.lastName), \(friend.birthday)); + """ + ) + + // 8. Query database for the friend we just inserted + let rows = try await client.query(""" + SELECT id, given_name, last_name, birthday FROM "friends" WHERE given_name = \(friend.firstName) + """ + ) + + // 9. Iterate the returned rows, decoding the rows into Swift primitives + for try await (id, firstName, lastName, birthday) in rows.decode((Int, String, String, Date).self) { + print("\(id) | \(firstName) \(lastName), \(birthday)") + } + + // 10. Shutdown the client, by cancelling its run method, through cancelling the taskGroup. + taskGroup.cancelAll() + } + } +} + diff --git a/Snippets/PostgresClient.swift b/Snippets/PostgresClient.swift new file mode 100644 index 00000000..9bfacc28 --- /dev/null +++ b/Snippets/PostgresClient.swift @@ -0,0 +1,40 @@ +import PostgresNIO +import struct Foundation.UUID + +@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) +enum Runner { + static func main() async throws { + +// snippet.configuration +let config = PostgresClient.Configuration( + host: "localhost", + port: 5432, + username: "my_username", + password: "my_password", + database: "my_database", + tls: .disable +) +// snippet.end + +// snippet.makeClient +let client = PostgresClient(configuration: config) +// snippet.end + + } + + static func runAndCancel(client: PostgresClient) async { +// snippet.run +await withTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() // !important + } + + // You can use the client while the `client.run()` method is not cancelled. + + // To shutdown the client, cancel its run method, by cancelling the taskGroup. + taskGroup.cancelAll() +} +// snippet.end + } +} + diff --git a/Sources/PostgresNIO/Docs.docc/coding.md b/Sources/PostgresNIO/Docs.docc/coding.md new file mode 100644 index 00000000..3bcc4a7e --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/coding.md @@ -0,0 +1,39 @@ +# PostgreSQL data types + +Translate Swift data types to Postgres data types and vica versa. Learn how to write translations +for your own custom Swift types. + +## Topics + +### Essentials + +- ``PostgresCodable`` +- ``PostgresDataType`` +- ``PostgresFormat`` +- ``PostgresNumeric`` + +### Encoding + +- ``PostgresEncodable`` +- ``PostgresNonThrowingEncodable`` +- ``PostgresDynamicTypeEncodable`` +- ``PostgresThrowingDynamicTypeEncodable`` +- ``PostgresArrayEncodable`` +- ``PostgresRangeEncodable`` +- ``PostgresRangeArrayEncodable`` +- ``PostgresEncodingContext`` + +### Decoding + +- ``PostgresDecodable`` +- ``PostgresArrayDecodable`` +- ``PostgresRangeDecodable`` +- ``PostgresRangeArrayDecodable`` +- ``PostgresDecodingContext`` + +### JSON + +- ``PostgresJSONEncoder`` +- ``PostgresJSONDecoder`` + + diff --git a/Sources/PostgresNIO/Docs.docc/deprecated.md b/Sources/PostgresNIO/Docs.docc/deprecated.md new file mode 100644 index 00000000..a29465f6 --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/deprecated.md @@ -0,0 +1,43 @@ +# Deprecations + +`PostgresNIO` follows SemVer 2.0.0. Learn which APIs are considered deprecated and how to migrate to +their replacements. + +``PostgresNIO`` reached 1.0 in April 2020. Since then the maintainers have been hard at work to +guarantee API stability. However as the Swift and Swift on server ecosystem have matured approaches +have changed. The introduction of structured concurrency changed what developers expect from a +modern Swift library. Because of this ``PostgresNIO`` added various APIs that embrace the new Swift +patterns. This means however, that PostgresNIO still offers APIs that have fallen out of favor. +Those are documented here. All those APIs will be removed once the maintainers release the next +major version. The maintainers recommend all adopters to move of those APIs sooner rather than +later. + +## Topics + +### Migrate of deprecated APIs + +- + +### Deprecated APIs + +These types are already deprecated or will be deprecated in the near future. All of them will be +removed from the public API with the next major release. + +- ``PostgresDatabase`` +- ``PostgresData`` +- ``PostgresDataConvertible`` +- ``PostgresQueryResult`` +- ``PostgresJSONCodable`` +- ``PostgresJSONBCodable`` +- ``PostgresMessageEncoder`` +- ``PostgresMessageDecoder`` +- ``PostgresRequest`` +- ``PostgresMessage`` +- ``PostgresMessageType`` +- ``PostgresFormatCode`` +- ``PostgresListenContext`` +- ``PreparedQuery`` +- ``SASLAuthenticationManager`` +- ``SASLAuthenticationMechanism`` +- ``SASLAuthenticationError`` +- ``SASLAuthenticationStepResult`` diff --git a/Sources/PostgresNIO/Docs.docc/index.md b/Sources/PostgresNIO/Docs.docc/index.md index ebe27cd0..6355a7a4 100644 --- a/Sources/PostgresNIO/Docs.docc/index.md +++ b/Sources/PostgresNIO/Docs.docc/index.md @@ -8,80 +8,51 @@ ## Overview -Features: - -- A ``PostgresConnection`` which allows you to connect to, authorize with, query, and retrieve results from a PostgreSQL server using [SwiftNIO]. -- An async/await interface that supports backpressure -- Automatic conversions between Swift primitive types and the Postgres wire format -- Integrated with the Swift server ecosystem, including use of [SwiftLog]. -- Designed to run efficiently on all supported platforms (tested extensively on Linux and Darwin systems) -- Support for `Network.framework` when available (e.g. on Apple platforms) +``PostgresNIO`` allows you to connect to, authorize with, query, and retrieve results from a +PostgreSQL server. PostgreSQL is an open source relational database. + +Use a ``PostgresConnection`` to create a connection to the PostgreSQL server. You can then use it to +run queries and prepared statements against the server. ``PostgresConnection`` also supports +PostgreSQL's Listen & Notify API. + +Developers, who don't want to manage connections themselves, can use the ``PostgresClient``, which +offers the same functionality as ``PostgresConnection``. ``PostgresClient`` +pools connections for rapid connection reuse and hides the complexities of connection +management from the user, allowing developers to focus on their SQL queries. ``PostgresClient`` +implements the `Service` protocol from Service Lifecycle allowing an easy adoption in Swift server +applications. + +``PostgresNIO`` embraces Swift structured concurrency, offering async/await APIs which handle +task cancellation. The query interface makes use of backpressure to ensure that memory can not grow +unbounded for queries that return thousands of rows. + +``PostgresNIO`` runs efficiently on Linux and Apple platforms. On Apple platforms developers can +configure ``PostgresConnection`` to use `Network.framework` as the underlying transport framework. ## Topics -### Articles - -- - -### Connections +### Essentials +- ``PostgresClient`` +- ``PostgresClient/Configuration`` - ``PostgresConnection`` +- -### Querying - -- ``PostgresQuery`` -- ``PostgresBindings`` -- ``PostgresRow`` -- ``PostgresRowSequence`` -- ``PostgresRandomAccessRow`` -- ``PostgresCell`` -- ``PreparedQuery`` -- ``PostgresQueryMetadata`` - -### Encoding and Decoding +### Advanced -- ``PostgresEncodable`` -- ``PostgresEncodingContext`` -- ``PostgresDecodable`` -- ``PostgresDecodingContext`` -- ``PostgresArrayEncodable`` -- ``PostgresArrayDecodable`` -- ``PostgresJSONEncoder`` -- ``PostgresJSONDecoder`` -- ``PostgresDataType`` -- ``PostgresFormat`` -- ``PostgresNumeric`` - -### Notifications - -- ``PostgresListenContext`` +- +- +- ### Errors - ``PostgresError`` - ``PostgresDecodingError`` +- ``PSQLError`` + +### Deprecations -### Deprecated - -These types are already deprecated or will be deprecated in the near future. All of them will be -removed from the public API with the next major release. - -- ``PostgresDatabase`` -- ``PostgresData`` -- ``PostgresDataConvertible`` -- ``PostgresQueryResult`` -- ``PostgresJSONCodable`` -- ``PostgresJSONBCodable`` -- ``PostgresMessageEncoder`` -- ``PostgresMessageDecoder`` -- ``PostgresRequest`` -- ``PostgresMessage`` -- ``PostgresMessageType`` -- ``PostgresFormatCode`` -- ``SASLAuthenticationManager`` -- ``SASLAuthenticationMechanism`` -- ``SASLAuthenticationError`` -- ``SASLAuthenticationStepResult`` +- [SwiftNIO]: https://github.com/apple/swift-nio [SwiftLog]: https://github.com/apple/swift-log diff --git a/Sources/PostgresNIO/Docs.docc/listen.md b/Sources/PostgresNIO/Docs.docc/listen.md new file mode 100644 index 00000000..10c5d8bf --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/listen.md @@ -0,0 +1,9 @@ +# Listen & Notify + +``PostgresNIO`` supports PostgreSQL's listen and notify API. Learn how to listen for changes and +notify other listeners. + +## Topics + +- ``PostgresNotification`` +- ``PostgresNotificationSequence`` diff --git a/Sources/PostgresNIO/Docs.docc/migrations.md b/Sources/PostgresNIO/Docs.docc/migrations.md index 7185ba06..3a7c634a 100644 --- a/Sources/PostgresNIO/Docs.docc/migrations.md +++ b/Sources/PostgresNIO/Docs.docc/migrations.md @@ -87,16 +87,4 @@ connection.query("SELECT id, name, email, age FROM users").whenComplete { } ``` -## Topics - -### Relevant types - -- ``PostgresConnection`` -- ``PostgresQuery`` -- ``PostgresBindings`` -- ``PostgresRow`` -- ``PostgresRandomAccessRow`` -- ``PostgresEncodable`` -- ``PostgresDecodable`` - [`1.9.0`]: https://github.com/vapor/postgres-nio/releases/tag/1.9.0 diff --git a/Sources/PostgresNIO/Docs.docc/prepared-statement.md b/Sources/PostgresNIO/Docs.docc/prepared-statement.md new file mode 100644 index 00000000..ff4b1c62 --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/prepared-statement.md @@ -0,0 +1,7 @@ +# Boosting Performance with Prepared Statements + +Improve performance by leveraging PostgreSQL's prepared statements. + +## Topics + +- ``PostgresPreparedStatement`` diff --git a/Sources/PostgresNIO/Docs.docc/running-queries.md b/Sources/PostgresNIO/Docs.docc/running-queries.md new file mode 100644 index 00000000..b2c4586f --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/running-queries.md @@ -0,0 +1,27 @@ +# Running Queries + +Interact with the PostgreSQL database by running Queries. + +## Overview + + + +You interact with the Postgres database by running SQL [Queries]. + + + +``PostgresQuery`` conforms to + + +## Topics + +- ``PostgresQuery`` +- ``PostgresBindings`` +- ``PostgresRow`` +- ``PostgresRowSequence`` +- ``PostgresRandomAccessRow`` +- ``PostgresCell`` +- ``PostgresQueryMetadata`` + +[Queries]: doc:PostgresQuery +[`ExpressibleByStringInterpolation`]: https://developer.apple.com/documentation/swift/expressiblebystringinterpolation diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 865dafc8..9383ffcd 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -8,29 +8,28 @@ import _ConnectionPoolModule /// A Postgres client that is backed by an underlying connection pool. Use ``Configuration`` to change the client's /// behavior. /// -/// > Warning: -/// The client can only lease connections if the user is running the client's ``run()`` method in a long running task: +/// ## Creating a client /// -/// ```swift -/// let client = PostgresClient(configuration: configuration) -/// await withTaskGroup(of: Void.self) { -/// taskGroup.addTask { -/// client.run() // !important -/// } +/// You create a ``PostgresClient`` by first creating a ``PostgresClient/Configuration`` struct that you can +/// use to modify the client's behavior. /// -/// do { -/// let rows = try await connection.query("SELECT userID, name, age FROM users;") -/// for try await (userID, name, age) in rows.decode((UUID, String, Int).self) { -/// // do something with the values -/// } -/// } catch { -/// // handle errors -/// } -/// -/// // shutdown the client -/// taskGroup.cancelAll() -/// } -/// ``` +/// @Snippet(path: "postgres-nio/Snippets/PostgresClient", slice: "configuration") +/// +/// Now you can create a client with your configuration object: +/// +/// @Snippet(path: "postgres-nio/Snippets/PostgresClient", slice: "makeClient") +/// +/// ## Running a client +/// +/// ``PostgresClient`` relies on structured concurrency. Because of this it needs a task in which it can schedule all the +/// background work that it needs to do in order to manage connections on the users behave. For this reason, developers +/// must provide a task to the client by scheduling the client's run method in a long running task: +/// +/// @Snippet(path: "postgres-nio/Snippets/PostgresClient", slice: "run") +/// +/// ``PostgresClient`` can not lease connections, if its ``run()`` method isn't active. Cancelling the ``run()`` method +/// is equivalent to closing the client. Once a client's ``run()`` method has been cancelled, executing queries or prepared +/// statements will fail. @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) public final class PostgresClient: Sendable, ServiceLifecycle.Service { public struct Configuration: Sendable { @@ -247,7 +246,9 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { let backgroundLogger: Logger /// Creates a new ``PostgresClient``, that does not log any background information. - /// Don't forget to run ``run()`` the client in a long running task. + /// + /// > Warning: + /// The client can only lease connections if the user is running the client's ``run()`` method in a long running task. /// /// - Parameters: /// - configuration: The client's configuration. See ``Configuration`` for details. @@ -399,10 +400,21 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { } } - /// The client's run method. Users must call this function in order to start the client's background task processing - /// like creating and destroying connections and running timers. + /// The structured root task for the client's background work. + /// + /// > Warning: + /// Users must call this function in order to allow the client to process any background work. Executing queries, + /// prepared statements or leasing connections will hang until the developer executes the client's ``run()`` + /// method. + /// + /// Cancelling the task which executes the ``run()`` method, is equivalent to closing the client. Once the task + /// has been cancelled the client is not able to process any new queries or prepared statements. + /// + /// @Snippet(path: "postgres-nio/Snippets/PostgresClient", slice: "run") /// - /// Calls to ``withConnection(_:)`` will emit a `logger` warning, if ``run()`` hasn't been called previously. + /// > Note: + /// ``PostgresClient`` implements [ServiceLifecycle](https://github.com/swift-server/swift-service-lifecycle)'s `Service` protocol. Because of this + /// ``PostgresClient`` can be passed to a `ServiceGroup` for easier lifecycle management. public func run() async { let atomicOp = self.runningAtomic.compareExchange(expected: false, desired: true, ordering: .relaxed) precondition(!atomicOp.original, "PostgresClient.run() should just be called once!") From b6496eb211a0d5c225bcc6d3ff4f26c2dd4238de Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Fri, 8 Mar 2024 11:52:04 -0600 Subject: [PATCH 229/292] Fix multiple array type mapping mistakes and add missing date and time array types (#463) * Add missing definitions for Postgres type OIDs 1182 and 1183 (_date and _time), fix typos in the `macaddr8Array` and `datemultirange` types, and add missing array mappings for `timestamp` and `tstzrange`. * Add PostgresArrayCodable conformance for Date * Add tests for date arrays. * Fix test to account for rounding error in conversion to days during Postgres encoding --- .../PostgresNIO/Data/PostgresDataType.swift | 30 +++++++++++---- .../New/Data/Array+PostgresCodable.swift | 7 ++++ Tests/IntegrationTests/PostgresNIOTests.swift | 38 +++++++++++++++++++ .../New/Data/Array+PSQLCodableTests.swift | 4 ++ 4 files changed, 72 insertions(+), 7 deletions(-) diff --git a/Sources/PostgresNIO/Data/PostgresDataType.swift b/Sources/PostgresNIO/Data/PostgresDataType.swift index f3ab4dca..c3e4e747 100644 --- a/Sources/PostgresNIO/Data/PostgresDataType.swift +++ b/Sources/PostgresNIO/Data/PostgresDataType.swift @@ -113,12 +113,14 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri /// `774` public static let macaddr8 = PostgresDataType(774) /// `775` - public static let macaddr8Aray = PostgresDataType(775) + @available(*, deprecated, renamed: "macaddr8Array") + public static let macaddr8Aray = Self.macaddr8Array + public static let macaddr8Array = PostgresDataType(775) /// `790` public static let money = PostgresDataType(790) /// `791` @available(*, deprecated, renamed: "moneyArray") - public static let _money = PostgresDataType(791) + public static let _money = Self.moneyArray public static let moneyArray = PostgresDataType(791) /// `829` public static let macaddr = PostgresDataType(829) @@ -192,6 +194,10 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri public static let timestamp = PostgresDataType(1114) /// `1115` _timestamp public static let timestampArray = PostgresDataType(1115) + /// `1182` + public static let dateArray = PostgresDataType(1182) + /// `1183` + public static let timeArray = PostgresDataType(1183) /// `1184` public static let timestamptz = PostgresDataType(1184) /// `1185` @@ -446,7 +452,7 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .circle: return "CIRCLE" case .circleArray: return "CIRCLE[]" case .macaddr8: return "MACADDR8" - case .macaddr8Aray: return "MACADDR8[]" + case .macaddr8Array: return "MACADDR8[]" case .money: return "MONEY" case .moneyArray: return "MONEY[]" case .macaddr: return "MACADDR" @@ -485,6 +491,8 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .time: return "TIME" case .timestamp: return "TIMESTAMP" case .timestampArray: return "TIMESTAMP[]" + case .dateArray: return "DATE[]" + case .timeArray: return "TIME[]" case .timestamptz: return "TIMESTAMPTZ" case .timestamptzArray: return "TIMESTAMPTZ[]" case .interval: return "INTERVAL" @@ -596,7 +604,7 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .line: return .lineArray case .cidr: return .cidrArray case .circle: return .circleArray - case .macaddr8Aray: return .macaddr8 + case .macaddr8: return .macaddr8Array case .money: return .moneyArray case .int2vector: return .int2vectorArray case .regproc: return .regprocArray @@ -613,6 +621,9 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .aclitem: return .aclitemArray case .macaddr: return .macaddrArray case .inet: return .inetArray + case .timestamp: return .timestampArray + case .date: return .dateArray + case .time: return .timeArray case .timestamptz: return .timestamptzArray case .interval: return .intervalArray case .numeric: return .numericArray @@ -635,6 +646,7 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .regdictionary: return .regdictionaryArray case .numrange: return .numrangeArray case .tsrange: return .tsrangeArray + case .tstzrange: return .tstzrangeArray case .daterange: return .daterangeArray case .jsonpath: return .jsonpathArray case .regnamespace: return .regnamespaceArray @@ -643,7 +655,7 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .int4multirange: return .int4multirangeArray case .tsmultirange: return .tsmultirangeArray case .tstzmultirange: return .tstzmultirangeArray - case .datemultirange: return .datemultirange + case .datemultirange: return .datemultirangeArray case .int8multirange: return .int8multirangeArray case .bool: return .boolArray case .bytea: return .byteaArray @@ -677,7 +689,7 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .lineArray: return .line case .cidrArray: return .cidr case .circleArray: return .circle - case .macaddr8: return .macaddr8Aray + case .macaddr8Array: return .macaddr8 case .moneyArray: return .money case .int2vectorArray: return .int2vector case .regprocArray: return .regproc @@ -694,6 +706,9 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .aclitemArray: return .aclitem case .macaddrArray: return .macaddr case .inetArray: return .inet + case .timestampArray: return .timestamp + case .dateArray: return .date + case .timeArray: return .time case .timestamptzArray: return .timestamptz case .intervalArray: return .interval case .numericArray: return .numeric @@ -716,6 +731,7 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .regdictionaryArray: return .regdictionary case .numrangeArray: return .numrange case .tsrangeArray: return .tsrange + case .tstzrangeArray: return .tstzrange case .daterangeArray: return .daterange case .jsonpathArray: return .jsonpath case .regnamespaceArray: return .regnamespace @@ -724,7 +740,7 @@ public struct PostgresDataType: RawRepresentable, Sendable, Hashable, CustomStri case .int4multirangeArray: return .int4multirange case .tsmultirangeArray: return .tsmultirange case .tstzmultirangeArray: return .tstzmultirange - case .datemultirange: return .datemultirange + case .datemultirangeArray: return .datemultirange case .int8multirangeArray: return .int8multirange case .boolArray: return .bool case .byteaArray: return .bytea diff --git a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift index d605a6c1..ddab0fff 100644 --- a/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PostgresCodable.swift @@ -1,4 +1,5 @@ import NIOCore +import struct Foundation.Date import struct Foundation.UUID // MARK: Protocols @@ -85,6 +86,12 @@ extension UUID: PostgresArrayEncodable { public static var psqlArrayType: PostgresDataType { .uuidArray } } +extension Date: PostgresArrayDecodable {} + +extension Date: PostgresArrayEncodable { + public static var psqlArrayType: PostgresDataType { .timestamptzArray } +} + extension Range: PostgresArrayDecodable where Bound: PostgresRangeArrayDecodable {} extension Range: PostgresArrayEncodable where Bound: PostgresRangeArrayEncodable { diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index ea4d8d05..de6aaf73 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -783,6 +783,44 @@ final class PostgresNIOTests: XCTestCase { XCTAssertEqual(row?[data: "array"].array(of: Int64?.self), [1, nil, 3]) } + @available(*, deprecated, message: "Testing deprecated functionality") + func testDateArraySerialize() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow( try conn?.close().wait() ) } + let date1 = Date(timeIntervalSince1970: 1704088800), + date2 = Date(timeIntervalSince1970: 1706767200), + date3 = Date(timeIntervalSince1970: 1709272800) + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query(""" + select + $1::timestamptz[] as array + """, [ + PostgresData(array: [date1, date2, date3]) + ]).wait()) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual(row?[data: "array"].array(of: Date.self), [date1, date2, date3]) + } + + @available(*, deprecated, message: "Testing deprecated functionality") + func testDateArraySerializeAsPostgresDate() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + let date1 = Date(timeIntervalSince1970: 1704088800),//8766 + date2 = Date(timeIntervalSince1970: 1706767200),//8797 + date3 = Date(timeIntervalSince1970: 1709272800) //8826 + var data = PostgresData(array: [date1, date2, date3].map { Int32(($0.timeIntervalSince1970 - 946_684_800) / 86_400).postgresData }, elementType: .date) + data.type = .dateArray // N.B.: `.date` format is an Int32 count of days since psqlStartDate + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query("select $1::date[] as array", [data]).wait()) + let row = rows?.first?.makeRandomAccess() + XCTAssertEqual( + row?[data: "array"].array(of: Date.self)?.map { Int32((($0.timeIntervalSince1970 - 946_684_800) / 86_400).rounded(.toNearestOrAwayFromZero)) }, + [date1, date2, date3].map { Int32((($0.timeIntervalSince1970 - 946_684_800) / 86_400).rounded(.toNearestOrAwayFromZero)) } + ) + } + // https://github.com/vapor/postgres-nio/issues/143 func testEmptyStringFromNonNullColumn() { var conn: PostgresConnection? diff --git a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift index 79d47c30..bfffef52 100644 --- a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift @@ -56,6 +56,10 @@ class Array_PSQLCodableTests: XCTestCase { XCTAssertEqual(UUID.psqlType, .uuid) XCTAssertEqual([UUID].psqlType, .uuidArray) + XCTAssertEqual(Date.psqlArrayType, .timestamptzArray) + XCTAssertEqual(Date.psqlType, .timestamptz) + XCTAssertEqual([Date].psqlType, .timestamptzArray) + XCTAssertEqual(Range.psqlArrayType, .int4RangeArray) XCTAssertEqual(Range.psqlType, .int4Range) XCTAssertEqual([Range].psqlType, .int4RangeArray) From 43929b0fa76dae1c3679ea6bea49737b1c94cf40 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Fri, 8 Mar 2024 11:59:32 -0600 Subject: [PATCH 230/292] Minor package cleanup (#464) * Disable CodeQL CI, since GitHub seems disinclined to fix their mistakes. * Fix a few very minor issues in the API docs and README. * Make LOG_LEVEL env actually work in tests * Update CI for Swift 5.10 release * We only need two macOS tests, not four --- .github/workflows/test.yml | 20 +++--- README.md | 6 +- .../PostgresNIO/Docs.docc/images/article.svg | 1 - .../Docs.docc/images/vapor-postgres-logo.svg | 60 ------------------ .../images/vapor-postgresnio-logo.svg | 21 +++++++ .../PostgresNIO/Docs.docc/theme-settings.json | 61 ++++++------------- Tests/IntegrationTests/PostgresNIOTests.swift | 8 ++- Tests/IntegrationTests/Utilities.swift | 22 +++---- 8 files changed, 67 insertions(+), 132 deletions(-) delete mode 100644 Sources/PostgresNIO/Docs.docc/images/article.svg delete mode 100644 Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg create mode 100644 Sources/PostgresNIO/Docs.docc/images/vapor-postgresnio-logo.svg diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3d1f44a4..49d2cef1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,10 +21,10 @@ jobs: - swift:5.7-jammy - swift:5.8-jammy - swift:5.9-jammy - - swiftlang/swift:nightly-5.10-jammy + - swift:5.10-jammy - swiftlang/swift:nightly-main-jammy include: - - swift-image: swift:5.9-jammy + - swift-image: swift:5.10-jammy code-coverage: true container: ${{ matrix.swift-image }} runs-on: ubuntu-latest @@ -63,7 +63,7 @@ jobs: - postgres-image: postgres:12 postgres-auth: trust container: - image: swift:5.9-jammy + image: swift:5.10-jammy volumes: [ 'pgrunshare:/var/run/postgresql' ] runs-on: ubuntu-latest env: @@ -140,7 +140,12 @@ jobs: xcode-version: - '~14.3' - '~15.0' - runs-on: macos-13 + include: + - xcode-version: '~14.3' + macos-version: 'macos-13' + - xcode-version: '~15.0' + macos-version: 'macos-14' + runs-on: ${{ matrix.macos-version }} env: POSTGRES_HOSTNAME: 127.0.0.1 POSTGRES_USER: 'test_username' @@ -188,8 +193,9 @@ jobs: swift package diagnose-api-breaking-changes origin/main gh-codeql: + if: ${{ false }} runs-on: ubuntu-latest - container: swift:5.9-jammy + container: swift:jammy permissions: { actions: write, contents: read, security-events: write } steps: - name: Check out code @@ -197,10 +203,10 @@ jobs: - name: Mark repo safe in non-fake global config run: git config --global --add safe.directory "${GITHUB_WORKSPACE}" - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 with: languages: swift - name: Perform build run: swift build - name: Run CodeQL analyze - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 diff --git a/README.md b/README.md index c2dc545e..9e7d4e3b 100644 --- a/README.md +++ b/README.md @@ -16,10 +16,10 @@ Continuous Integration - Swift 5.7 + + Swift 5.7+ - - SSWG Incubation Level: Graduated + + SSWG Incubation Level: Graduated

diff --git a/Sources/PostgresNIO/Docs.docc/images/article.svg b/Sources/PostgresNIO/Docs.docc/images/article.svg deleted file mode 100644 index 3dc6a66c..00000000 --- a/Sources/PostgresNIO/Docs.docc/images/article.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg b/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg deleted file mode 100644 index 2b3fe0b1..00000000 --- a/Sources/PostgresNIO/Docs.docc/images/vapor-postgres-logo.svg +++ /dev/null @@ -1,60 +0,0 @@ - - - PostgresNIO - - - - - - - - - - - - - - - - - - diff --git a/Sources/PostgresNIO/Docs.docc/images/vapor-postgresnio-logo.svg b/Sources/PostgresNIO/Docs.docc/images/vapor-postgresnio-logo.svg new file mode 100644 index 00000000..a831189d --- /dev/null +++ b/Sources/PostgresNIO/Docs.docc/images/vapor-postgresnio-logo.svg @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + + diff --git a/Sources/PostgresNIO/Docs.docc/theme-settings.json b/Sources/PostgresNIO/Docs.docc/theme-settings.json index a8042a54..dda76197 100644 --- a/Sources/PostgresNIO/Docs.docc/theme-settings.json +++ b/Sources/PostgresNIO/Docs.docc/theme-settings.json @@ -1,46 +1,21 @@ { - "theme": { - "aside": { - "border-radius": "6px", - "border-style": "double", - "border-width": "3px" - }, - "border-radius": "0", - "button": { - "border-radius": "16px", - "border-width": "1px", - "border-style": "solid" - }, - "code": { - "border-radius": "16px", - "border-width": "1px", - "border-style": "solid" - }, - "color": { - "fill": { - "dark": "rgb(0, 0, 0)", - "light": "rgb(255, 255, 255)" - }, - "psql-blue": "#336791", - "documentation-intro-fill": "radial-gradient(circle at top, var(--color-documentation-intro-accent) 30%, #000 100%)", - "documentation-intro-accent": "var(--color-psql-blue)", - "documentation-intro-accent-outer": { - "dark": "rgb(255, 255, 255)", - "light": "rgb(0, 0, 0)" - }, - "documentation-intro-accent-inner": { - "dark": "rgb(0, 0, 0)", - "light": "rgb(255, 255, 255)" - } - }, - "icons": { - "technology": "/postgresnio/images/vapor-postgres-logo.svg", - "article": "/postgresnio/images/article.svg" - } + "theme": { + "aside": { "border-radius": "6px", "border-style": "double", "border-width": "3px" }, + "border-radius": "0", + "button": { "border-radius": "16px", "border-width": "1px", "border-style": "solid" }, + "code": { "border-radius": "16px", "border-width": "1px", "border-style": "solid" }, + "color": { + "psqlnio": "#336791", + "documentation-intro-fill": "radial-gradient(circle at top, var(--color-psqlnio) 30%, #000 100%)", + "documentation-intro-accent": "var(--color-psqlnio)", + "logo-base": { "dark": "#fff", "light": "#000" }, + "logo-shape": { "dark": "#000", "light": "#fff" }, + "fill": { "dark": "#000", "light": "#fff" } }, - "features": { - "quickNavigation": { - "enable": true - } - } + "icons": { "technology": "/postgresnio/images/vapor-postgresnio-logo.svg" } + }, + "features": { + "quickNavigation": { "enable": true }, + "i18n": { "enable": true } + } } diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index de6aaf73..88df2519 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -9,12 +9,14 @@ import NIOSSL final class PostgresNIOTests: XCTestCase { private var group: EventLoopGroup! - private var eventLoop: EventLoop { self.group.next() } + override class func setUp() { + XCTAssertTrue(isLoggingConfigured) + } + override func setUpWithError() throws { try super.setUpWithError() - XCTAssertTrue(isLoggingConfigured) self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1) } @@ -1475,7 +1477,7 @@ final class PostgresNIOTests: XCTestCase { let isLoggingConfigured: Bool = { LoggingSystem.bootstrap { label in var handler = StreamLogHandler.standardOutput(label: label) - handler.logLevel = env("LOG_LEVEL").flatMap { Logger.Level(rawValue: $0) } ?? .debug + handler.logLevel = env("LOG_LEVEL").flatMap { .init(rawValue: $0) } ?? .info return handler } return true diff --git a/Tests/IntegrationTests/Utilities.swift b/Tests/IntegrationTests/Utilities.swift index b1788110..001d9ee4 100644 --- a/Tests/IntegrationTests/Utilities.swift +++ b/Tests/IntegrationTests/Utilities.swift @@ -24,10 +24,8 @@ extension PostgresConnection { } } - static func test(on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { - var logger = Logger(label: "postgres.connection.test") - logger.logLevel = logLevel - + static func test(on eventLoop: EventLoop) -> EventLoopFuture { + let logger = Logger(label: "postgres.connection.test") let config = PostgresConnection.Configuration( host: env("POSTGRES_HOSTNAME") ?? "localhost", port: env("POSTGRES_PORT").flatMap(Int.init(_:)) ?? 5432, @@ -40,10 +38,8 @@ extension PostgresConnection { return PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger) } - static func testUDS(on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { - var logger = Logger(label: "postgres.connection.test") - logger.logLevel = logLevel - + static func testUDS(on eventLoop: EventLoop) -> EventLoopFuture { + let logger = Logger(label: "postgres.connection.test") let config = PostgresConnection.Configuration( unixSocketPath: env("POSTGRES_SOCKET") ?? "/tmp/.s.PGSQL.\(env("POSTGRES_PORT").flatMap(Int.init(_:)) ?? 5432)", username: env("POSTGRES_USER") ?? "test_username", @@ -54,10 +50,8 @@ extension PostgresConnection { return PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger) } - static func testChannel(_ channel: Channel, on eventLoop: EventLoop, logLevel: Logger.Level = .info) -> EventLoopFuture { - var logger = Logger(label: "postgres.connection.test") - logger.logLevel = logLevel - + static func testChannel(_ channel: Channel, on eventLoop: EventLoop) -> EventLoopFuture { + let logger = Logger(label: "postgres.connection.test") let config = PostgresConnection.Configuration( establishedChannel: channel, username: env("POSTGRES_USER") ?? "test_username", @@ -71,9 +65,7 @@ extension PostgresConnection { extension Logger { static var psqlTest: Logger { - var logger = Logger(label: "psql.test") - logger.logLevel = .info - return logger + .init(label: "psql.test") } } From 6f0fc054babeed13850f9014e03ced7a1d714868 Mon Sep 17 00:00:00 2001 From: Jia-Han Wu Date: Sat, 9 Mar 2024 04:27:46 +0800 Subject: [PATCH 231/292] Fix `reverseChunked(by:)` Method Implementation (#465) --- Sources/PostgresNIO/Data/PostgresData+Numeric.swift | 10 ++-------- Tests/IntegrationTests/PostgresNIOTests.swift | 8 ++++++-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/Sources/PostgresNIO/Data/PostgresData+Numeric.swift b/Sources/PostgresNIO/Data/PostgresData+Numeric.swift index 5e564d6d..e736a61c 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Numeric.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Numeric.swift @@ -268,16 +268,10 @@ private extension Collection { // splits the collection into chunks of the supplied size // if the collection is not evenly divisible, the first chunk will be smaller func reverseChunked(by maxSize: Int) -> [SubSequence] { - var lastDistance = 0 var chunkStartIndex = self.startIndex return stride(from: 0, to: self.count, by: maxSize).reversed().map { current in - let distance = (self.count - current) - lastDistance - lastDistance = distance - let chunkEndOffset = Swift.min( - self.distance(from: chunkStartIndex, to: self.endIndex), - distance - ) - let chunkEndIndex = self.index(chunkStartIndex, offsetBy: chunkEndOffset) + let distance = self.count - current + let chunkEndIndex = self.index(self.startIndex, offsetBy: distance) defer { chunkStartIndex = chunkEndIndex } return self[chunkStartIndex.. Date: Tue, 19 Mar 2024 02:27:23 -0500 Subject: [PATCH 232/292] Temporarily disable Thread Sanitizer in CI --- .github/workflows/test.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 49d2cef1..8c6c3897 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -38,11 +38,11 @@ jobs: swift --version - name: Check out package uses: actions/checkout@v4 - - name: Run unit tests with Thread Sanitizer + - name: Run unit tests env: CODE_COVERAGE: ${{ matrix.code-coverage && '--enable-code-coverage' || '' }} run: | - swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' --sanitize=thread ${CODE_COVERAGE} + swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' ${CODE_COVERAGE} - name: Submit code coverage if: ${{ matrix.code-coverage }} uses: vapor/swift-codecov-action@v0.2 @@ -139,11 +139,11 @@ jobs: - scram-sha-256 xcode-version: - '~14.3' - - '~15.0' + - '~15' include: - xcode-version: '~14.3' macos-version: 'macos-13' - - xcode-version: '~15.0' + - xcode-version: '~15' macos-version: 'macos-14' runs-on: ${{ matrix.macos-version }} env: @@ -175,7 +175,7 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - name: Run all tests - run: swift test + run: swift test --sanitize=thread api-breakage: if: github.event_name == 'pull_request' From 8f8724e496a8f26c0c13ceaa347647ac7248d6fd Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Tue, 26 Mar 2024 04:55:55 -0500 Subject: [PATCH 233/292] Turn Thread Sanitizer back on in CI (Github-side issue has been fixed) --- .github/workflows/test.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8c6c3897..7373e17d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -38,11 +38,11 @@ jobs: swift --version - name: Check out package uses: actions/checkout@v4 - - name: Run unit tests + - name: Run unit tests with Thread Sanitizer env: CODE_COVERAGE: ${{ matrix.code-coverage && '--enable-code-coverage' || '' }} run: | - swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' ${CODE_COVERAGE} + swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' --sanitize=thread ${CODE_COVERAGE} - name: Submit code coverage if: ${{ matrix.code-coverage }} uses: vapor/swift-codecov-action@v0.2 @@ -175,7 +175,7 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - name: Run all tests - run: swift test --sanitize=thread + run: swift test api-breakage: if: github.event_name == 'pull_request' From 35587e988316ee42924d7bb72e5cb14735c75470 Mon Sep 17 00:00:00 2001 From: Jay Herron <30518755+NeedleInAJayStack@users.noreply.github.com> Date: Tue, 26 Mar 2024 04:35:26 -0700 Subject: [PATCH 234/292] Fixes `LISTEN` to quote channel name (#466) Co-authored-by: Fabian Fett --- .../New/PostgresChannelHandler.swift | 4 +-- Tests/IntegrationTests/AsyncTests.swift | 29 ++++++++++++------- .../New/PostgresConnectionTests.swift | 10 +++---- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 32dea4a5..53dbd8c9 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -594,7 +594,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private func makeStartListeningQuery(channel: String, context: ChannelHandlerContext) -> PSQLTask { let promise = context.eventLoop.makePromise(of: PSQLRowStream.self) let query = ExtendedQueryContext( - query: PostgresQuery(unsafeSQL: "LISTEN \(channel);"), + query: PostgresQuery(unsafeSQL: #"LISTEN "\#(channel)";"#), logger: self.logger, promise: promise ) @@ -642,7 +642,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private func makeUnlistenQuery(channel: String, context: ChannelHandlerContext) -> PSQLTask { let promise = context.eventLoop.makePromise(of: PSQLRowStream.self) let query = ExtendedQueryContext( - query: PostgresQuery(unsafeSQL: "UNLISTEN \(channel);"), + query: PostgresQuery(unsafeSQL: #"UNLISTEN "\#(channel)";"#), logger: self.logger, promise: promise ) diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 75e5b6ba..ce6fe027 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -225,25 +225,32 @@ final class AsyncPostgresConnectionTests: XCTestCase { } func testListenAndNotify() async throws { + let channelNames = [ + "foo", + "default" + ] + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - try await self.withTestConnection(on: eventLoop) { connection in - let stream = try await connection.listen("foo") - var iterator = stream.makeAsyncIterator() + for channelName in channelNames { + try await self.withTestConnection(on: eventLoop) { connection in + let stream = try await connection.listen(channelName) + var iterator = stream.makeAsyncIterator() - try await self.withTestConnection(on: eventLoop) { other in - try await other.query(#"NOTIFY foo, 'bar';"#, logger: .psqlTest) + try await self.withTestConnection(on: eventLoop) { other in + try await other.query(#"NOTIFY "\#(unescaped: channelName)", 'bar';"#, logger: .psqlTest) - try await other.query(#"NOTIFY foo, 'foo';"#, logger: .psqlTest) - } + try await other.query(#"NOTIFY "\#(unescaped: channelName)", 'foo';"#, logger: .psqlTest) + } - let first = try await iterator.next() - XCTAssertEqual(first?.payload, "bar") + let first = try await iterator.next() + XCTAssertEqual(first?.payload, "bar") - let second = try await iterator.next() - XCTAssertEqual(second?.payload, "foo") + let second = try await iterator.next() + XCTAssertEqual(second?.payload, "foo") + } } } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index f2cd96f8..fe94633a 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -51,7 +51,7 @@ class PostgresConnectionTests: XCTestCase { } let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "LISTEN foo;") + XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) @@ -63,7 +63,7 @@ class PostgresConnectionTests: XCTestCase { try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) let unlistenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(unlistenMessage.parse.query, "UNLISTEN foo;") + XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) @@ -111,7 +111,7 @@ class PostgresConnectionTests: XCTestCase { } let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "LISTEN foo;") + XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) @@ -124,7 +124,7 @@ class PostgresConnectionTests: XCTestCase { try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo2"))) let unlistenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(unlistenMessage.parse.query, "UNLISTEN foo;") + XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) @@ -160,7 +160,7 @@ class PostgresConnectionTests: XCTestCase { } let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "LISTEN foo;") + XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) From e345cbb9cf6052b37b27c0c4f976134fc01dbe15 Mon Sep 17 00:00:00 2001 From: Jia-Han Wu Date: Tue, 26 Mar 2024 19:38:41 +0800 Subject: [PATCH 235/292] Fix broken link in README.md (#467) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9e7d4e3b..b6cecc2d 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ detailed look at all of the classes, structs, protocols, and more. ## Getting started -Interested in an example? We prepared a simple [Birthday example](/vapor/postgres-nio/tree/main/Snippets/Birthdays.swift) +Interested in an example? We prepared a simple [Birthday example](https://github.com/vapor/postgres-nio/blob/main/Snippets/Birthdays.swift) in the Snippets folder. #### Adding the dependency From ee5d5e159c9892df957e06ac9f1f357502270487 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Wed, 1 May 2024 09:23:50 +0100 Subject: [PATCH 236/292] Make `TLS.disable` a let instead of a var (#471) This currently emits a Sendable warning since a global var isn't Sendable safe. --- Sources/PostgresNIO/Pool/PostgresClient.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 9383ffcd..2116a51d 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -47,7 +47,7 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { } /// Do not try to create a TLS connection to the server. - public static var disable: Self = Self.init(.disable) + public static let disable: Self = Self.init(.disable) /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. /// If the server does not support TLS, create an insecure connection. From a48eebc4f9c83de18e608f5a096769427e1177b9 Mon Sep 17 00:00:00 2001 From: Mahdi Bahrami Date: Fri, 10 May 2024 01:33:20 +0330 Subject: [PATCH 237/292] Actually use additional connection parameters (#473) --- .../New/PostgresChannelHandler.swift | 3 +- Tests/IntegrationTests/AsyncTests.swift | 33 ++++++++++++++- Tests/IntegrationTests/Utilities.swift | 9 ++-- .../New/PostgresConnectionTests.swift | 42 +++++++++++++++++++ 4 files changed, 82 insertions(+), 5 deletions(-) diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 53dbd8c9..a3190aa7 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -390,7 +390,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { let authContext = AuthContext( username: username, password: self.configuration.password, - database: self.configuration.database + database: self.configuration.database, + additionalParameters: self.configuration.options.additionalStartupParameters ) let action = self.state.provideAuthenticationContext(authContext) return self.run(action, with: context) diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index ce6fe027..513157fd 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -84,6 +84,36 @@ final class AsyncPostgresConnectionTests: XCTestCase { } } + func testAdditionalParametersTakeEffect() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let query: PostgresQuery = """ + SELECT + current_setting('application_name'); + """ + + let applicationName = "postgres-nio-test" + var options = PostgresConnection.Configuration.Options() + options.additionalStartupParameters = [ + ("application_name", applicationName) + ] + + try await withTestConnection(on: eventLoop, options: options) { connection in + let rows = try await connection.query(query, logger: .psqlTest) + var counter = 0 + + for try await element in rows.decode(String.self) { + XCTAssertEqual(element, applicationName) + + counter += 1 + } + + XCTAssertGreaterThanOrEqual(counter, 1) + } + } + func testSelectTimeoutWhileLongRunningQuery() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } @@ -452,11 +482,12 @@ extension XCTestCase { func withTestConnection( on eventLoop: EventLoop, + options: PostgresConnection.Configuration.Options? = nil, file: StaticString = #filePath, line: UInt = #line, _ closure: (PostgresConnection) async throws -> Result ) async throws -> Result { - let connection = try await PostgresConnection.test(on: eventLoop).get() + let connection = try await PostgresConnection.test(on: eventLoop, options: options).get() do { let result = try await closure(connection) diff --git a/Tests/IntegrationTests/Utilities.swift b/Tests/IntegrationTests/Utilities.swift index 001d9ee4..91dbb62e 100644 --- a/Tests/IntegrationTests/Utilities.swift +++ b/Tests/IntegrationTests/Utilities.swift @@ -24,9 +24,9 @@ extension PostgresConnection { } } - static func test(on eventLoop: EventLoop) -> EventLoopFuture { + static func test(on eventLoop: EventLoop, options: Configuration.Options? = nil) -> EventLoopFuture { let logger = Logger(label: "postgres.connection.test") - let config = PostgresConnection.Configuration( + var config = PostgresConnection.Configuration( host: env("POSTGRES_HOSTNAME") ?? "localhost", port: env("POSTGRES_PORT").flatMap(Int.init(_:)) ?? 5432, username: env("POSTGRES_USER") ?? "test_username", @@ -34,7 +34,10 @@ extension PostgresConnection { database: env("POSTGRES_DB") ?? "test_database", tls: .disable ) - + if let options { + config.options = options + } + return PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger) } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index fe94633a..34528f7e 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -38,6 +38,48 @@ class PostgresConnectionTests: XCTestCase { } } + func testOptionsAreSentOnTheWire() async throws { + let eventLoop = NIOAsyncTestingEventLoop() + let channel = await NIOAsyncTestingChannel(handlers: [ + ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), + ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), + ], loop: eventLoop) + try await channel.connect(to: .makeAddressResolvingHost("localhost", port: 5432)) + + let configuration = { + var config = PostgresConnection.Configuration( + establishedChannel: channel, + username: "username", + password: "postgres", + database: "database" + ) + config.options.additionalStartupParameters = [ + ("DateStyle", "ISO, MDY"), + ("application_name", "postgres-nio-test"), + ("server_encoding", "UTF8"), + ("integer_datetimes", "on"), + ("client_encoding", "UTF8"), + ("TimeZone", "Etc/UTC"), + ("is_superuser", "on"), + ("server_version", "13.1 (Debian 13.1-1.pgdg100+1)"), + ("session_authorization", "postgres"), + ("IntervalStyle", "postgres"), + ("standard_conforming_strings", "on") + ] + return config + }() + + async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: .psqlTest) + let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: configuration.options.additionalStartupParameters, replication: .false)))) + try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) + try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + let connection = try await connectionPromise + try await connection.close() + } + func testSimpleListen() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() From e62cc88d244a075e0263b33edb54ef793cd5a1f8 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Tue, 28 May 2024 01:22:08 -0500 Subject: [PATCH 238/292] [CI] Update code coverage action, attempt fix for Homebrew nonsense (#476) --- .github/workflows/test.yml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7373e17d..808718fb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,6 +22,7 @@ jobs: - swift:5.8-jammy - swift:5.9-jammy - swift:5.10-jammy + - swiftlang/swift:nightly-6.0-jammy - swiftlang/swift:nightly-main-jammy include: - swift-image: swift:5.10-jammy @@ -45,7 +46,9 @@ jobs: swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' --sanitize=thread ${CODE_COVERAGE} - name: Submit code coverage if: ${{ matrix.code-coverage }} - uses: vapor/swift-codecov-action@v0.2 + uses: vapor/swift-codecov-action@v0.3 + with: + codecov_token: ${{ secrets.CODECOV_TOKEN }} linux-integration-and-dependencies: strategy: @@ -165,7 +168,7 @@ jobs: # ** BEGIN ** Work around bug in both Homebrew and GHA (brew upgrade python@3.11 || true) && (brew link --force --overwrite python@3.11 || true) (brew upgrade python@3.12 || true) && (brew link --force --overwrite python@3.12 || true) - brew upgrade + (brew upgrade || true) # ** END ** Work around bug in both Homebrew and GHA brew install --overwrite "${POSTGRES_FORMULA}" brew link --overwrite --force "${POSTGRES_FORMULA}" From d3795844d488210b65ace34c5f003e47d812d999 Mon Sep 17 00:00:00 2001 From: Johannes Weiss Date: Wed, 29 May 2024 15:48:59 +0100 Subject: [PATCH 239/292] Workaround DiscardingTaskGroup non-conformance with nightly compilers (#478) --- .../ConnectionPoolModule/ConnectionPool.swift | 20 +++++++++++++------ .../ConnectionPoolTests.swift | 14 ++++++------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index 9f25e82c..8ba0e7be 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -419,7 +419,7 @@ public final class ConnectionPool< @inlinable /*private*/ func makeConnection(for request: StateMachine.ConnectionRequest, in taskGroup: inout some TaskGroupProtocol) { - taskGroup.addTask { + taskGroup.addTask_ { self.observabilityDelegate.startedConnecting(id: request.connectionID) do { @@ -468,7 +468,7 @@ public final class ConnectionPool< /*private*/ func runKeepAlive(_ connection: Connection, in taskGroup: inout some TaskGroupProtocol) { self.observabilityDelegate.keepAliveTriggered(id: connection.id) - taskGroup.addTask { + taskGroup.addTask_ { do { try await self.keepAliveBehavior.runKeepAlive(for: connection) @@ -503,7 +503,7 @@ public final class ConnectionPool< @inlinable /*private*/ func runTimer(_ timer: StateMachine.Timer, in poolGroup: inout some TaskGroupProtocol) { - poolGroup.addTask { () async -> () in + poolGroup.addTask_ { () async -> () in await withTaskGroup(of: TimerRunResult.self, returning: Void.self) { taskGroup in taskGroup.addTask { do { @@ -587,17 +587,25 @@ extension AsyncStream { @usableFromInline protocol TaskGroupProtocol { - mutating func addTask(operation: @escaping @Sendable () async -> Void) + // We need to call this `addTask_` because some Swift versions define this + // under exactly this name and others have different attributes. So let's pick + // a name that doesn't clash anywhere and implement it using the standard `addTask`. + mutating func addTask_(operation: @escaping @Sendable () async -> Void) } #if swift(>=5.8) && os(Linux) || swift(>=5.9) @available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) -extension DiscardingTaskGroup: TaskGroupProtocol {} +extension DiscardingTaskGroup: TaskGroupProtocol { + @inlinable + mutating func addTask_(operation: @escaping @Sendable () async -> Void) { + self.addTask(priority: nil, operation: operation) + } +} #endif extension TaskGroup: TaskGroupProtocol { @inlinable - mutating func addTask(operation: @escaping @Sendable () async -> Void) { + mutating func addTask_(operation: @escaping @Sendable () async -> Void) { self.addTask(priority: nil, operation: operation) } } diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index 3e3c9d65..3c0e7a6b 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -26,7 +26,7 @@ final class ConnectionPoolTests: XCTestCase { // the same connection is reused 1000 times await withTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { + taskGroup.addTask_ { await pool.run() } @@ -82,14 +82,14 @@ final class ConnectionPoolTests: XCTestCase { } await withTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { + taskGroup.addTask_ { await pool.run() } let (blockCancelStream, blockCancelContinuation) = AsyncStream.makeStream(of: Void.self) let (blockConnCreationStream, blockConnCreationContinuation) = AsyncStream.makeStream(of: Void.self) - taskGroup.addTask { + taskGroup.addTask_ { _ = try? await factory.nextConnectAttempt { _ in blockCancelContinuation.yield() var iterator = blockConnCreationStream.makeAsyncIterator() @@ -127,7 +127,7 @@ final class ConnectionPoolTests: XCTestCase { } await withTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { + taskGroup.addTask_ { await pool.run() } @@ -170,12 +170,12 @@ final class ConnectionPoolTests: XCTestCase { // the same connection is reused 1000 times await withTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { + taskGroup.addTask_ { await pool.run() XCTAssertFalse(hasFinished.compareExchange(expected: false, desired: true, ordering: .relaxed).original) } - taskGroup.addTask { + taskGroup.addTask_ { var usedConnectionIDs = Set() for _ in 0.. Date: Thu, 30 May 2024 13:31:54 +0200 Subject: [PATCH 240/292] Fix crash when recreating minimal connections (#480) --- .../PoolStateMachine+ConnectionGroup.swift | 2 +- Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift index 833365fa..f26f244d 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -592,7 +592,7 @@ extension PoolStateMachine { let newConnectionRequest: ConnectionRequest? if self.connections.count < self.minimumConcurrentConnections { - newConnectionRequest = .init(connectionID: self.generator.next()) + newConnectionRequest = self.createNewConnection() } else { newConnectionRequest = .none } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift index f5ada14f..2f3ae617 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -375,6 +375,10 @@ final class PoolStateMachineTests: XCTestCase { let connectionClosed = stateMachine.connectionClosed(connection) XCTAssertEqual(connectionClosed.connection, .makeConnection(.init(connectionID: 1), [])) connection.closeIfClosing() + let establishAction = stateMachine.connectionEstablished(.init(id: 1), maxStreams: 1) + XCTAssertEqual(establishAction.request, .none) + guard case .scheduleTimers(let timers) = establishAction.connection else { return XCTFail("Unexpected connection action") } + XCTAssertEqual(timers, [.init(.init(timerID: 0, connectionID: 1, usecase: .keepAlive), duration: configuration.keepAliveDuration!)]) } } From 5c268768890b062803a49f1358becc478f954265 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 13 Jun 2024 18:54:18 +0200 Subject: [PATCH 241/292] Fix totally unnecessary `preconditionFailure` in `PSQLEventsHandler` (#481) --- Sources/PostgresNIO/New/PSQLEventsHandler.swift | 4 +--- Tests/PostgresNIOTests/New/PostgresConnectionTests.swift | 9 +++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/Sources/PostgresNIO/New/PSQLEventsHandler.swift b/Sources/PostgresNIO/New/PSQLEventsHandler.swift index 2bf0d6d8..0f426f20 100644 --- a/Sources/PostgresNIO/New/PSQLEventsHandler.swift +++ b/Sources/PostgresNIO/New/PSQLEventsHandler.swift @@ -68,10 +68,8 @@ final class PSQLEventsHandler: ChannelInboundHandler { case .authenticated: break } - case TLSUserEvent.shutdownCompleted: - break default: - preconditionFailure() + context.fireUserInboundEventTriggered(event) } } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 34528f7e..209522dd 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -416,6 +416,15 @@ class PostgresConnectionTests: XCTestCase { } } + func testWeDontCrashOnUnexpectedChannelEvents() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + enum MyEvent { + case pleaseDontCrash + } + channel.pipeline.fireUserInboundEventTriggered(MyEvent.pleaseDontCrash) + } + func testSerialExecutionOfSamePreparedStatement() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() From e7b9a08a11c0a4eedafb8032f13cfa764ae45b13 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Thu, 13 Jun 2024 12:18:17 -0500 Subject: [PATCH 242/292] [CI] Use Ubuntu 24.04 image, more code coverage, disable CodeQL completely (#482) * [CI] Use Ubuntu 24.04 image for Swift 5.10, upload code coverage more often, completely disable CodeQL * Add CODEOWNERS --- .github/CODEOWNERS | 1 + .github/workflows/test.yml | 50 +++++++++++++++++--------------------- 2 files changed, 23 insertions(+), 28 deletions(-) create mode 100644 .github/CODEOWNERS diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000..6413432f --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @fabianfett @gwynne diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 808718fb..f74427c3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,12 +21,9 @@ jobs: - swift:5.7-jammy - swift:5.8-jammy - swift:5.9-jammy - - swift:5.10-jammy + - swift:5.10-noble - swiftlang/swift:nightly-6.0-jammy - swiftlang/swift:nightly-main-jammy - include: - - swift-image: swift:5.10-jammy - code-coverage: true container: ${{ matrix.swift-image }} runs-on: ubuntu-latest steps: @@ -40,12 +37,9 @@ jobs: - name: Check out package uses: actions/checkout@v4 - name: Run unit tests with Thread Sanitizer - env: - CODE_COVERAGE: ${{ matrix.code-coverage && '--enable-code-coverage' || '' }} run: | - swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' --sanitize=thread ${CODE_COVERAGE} + swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' --sanitize=thread --enable-code-coverage - name: Submit code coverage - if: ${{ matrix.code-coverage }} uses: vapor/swift-codecov-action@v0.3 with: codecov_token: ${{ secrets.CODECOV_TOKEN }} @@ -66,7 +60,7 @@ jobs: - postgres-image: postgres:12 postgres-auth: trust container: - image: swift:5.10-jammy + image: swift:5.10-noble volumes: [ 'pgrunshare:/var/run/postgresql' ] runs-on: ubuntu-latest env: @@ -183,7 +177,7 @@ jobs: api-breakage: if: github.event_name == 'pull_request' runs-on: ubuntu-latest - container: swift:jammy + container: swift:noble steps: - name: Checkout uses: actions/checkout@v4 @@ -195,21 +189,21 @@ jobs: git config --global --add safe.directory "${GITHUB_WORKSPACE}" swift package diagnose-api-breaking-changes origin/main - gh-codeql: - if: ${{ false }} - runs-on: ubuntu-latest - container: swift:jammy - permissions: { actions: write, contents: read, security-events: write } - steps: - - name: Check out code - uses: actions/checkout@v4 - - name: Mark repo safe in non-fake global config - run: git config --global --add safe.directory "${GITHUB_WORKSPACE}" - - name: Initialize CodeQL - uses: github/codeql-action/init@v3 - with: - languages: swift - - name: Perform build - run: swift build - - name: Run CodeQL analyze - uses: github/codeql-action/analyze@v3 +# gh-codeql: +# if: ${{ false }} +# runs-on: ubuntu-latest +# container: swift:noble +# permissions: { actions: write, contents: read, security-events: write } +# steps: +# - name: Check out code +# uses: actions/checkout@v4 +# - name: Mark repo safe in non-fake global config +# run: git config --global --add safe.directory "${GITHUB_WORKSPACE}" +# - name: Initialize CodeQL +# uses: github/codeql-action/init@v3 +# with: +# languages: swift +# - name: Perform build +# run: swift build +# - name: Run CodeQL analyze +# uses: github/codeql-action/analyze@v3 From 5f541d05970a4fad5accb54365191f1f8e91ea3e Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 14 Jun 2024 11:49:36 +0200 Subject: [PATCH 243/292] Drop support for Swift 5.7 (#485) --- .github/workflows/test.yml | 1 - Package.swift | 2 +- README.md | 4 ++-- Sources/ConnectionPoolModule/ConnectionPool.swift | 8 ++++---- Sources/PostgresNIO/New/NotificationListener.swift | 2 +- .../New/PostgresNotificationSequence.swift | 7 +------ Sources/PostgresNIO/Utilities/Exports.swift | 11 ----------- 7 files changed, 9 insertions(+), 26 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f74427c3..1761880d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,7 +18,6 @@ jobs: fail-fast: false matrix: swift-image: - - swift:5.7-jammy - swift:5.8-jammy - swift:5.9-jammy - swift:5.10-noble diff --git a/Package.swift b/Package.swift index 4d008371..79c740f9 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.7 +// swift-tools-version:5.8 import PackageDescription let package = Package( diff --git a/README.md b/README.md index b6cecc2d..bc56953b 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Continuous Integration - Swift 5.7+ + Swift 5.8+ SSWG Incubation Level: Graduated @@ -167,7 +167,7 @@ Please see [SECURITY.md] for details on the security process. [Team Chat]: https://discord.gg/vapor [MIT License]: LICENSE [Continuous Integration]: https://github.com/vapor/postgres-nio/actions -[Swift 5.7]: https://swift.org +[Swift 5.8]: https://swift.org [Security.md]: https://github.com/vapor/.github/blob/main/SECURITY.md [`PostgresConnection`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresconnection diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index 8ba0e7be..3231cc06 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -273,7 +273,7 @@ public final class ConnectionPool< public func run() async { await withTaskCancellationHandler { - #if swift(>=5.8) && os(Linux) || swift(>=5.9) + #if os(Linux) || compiler(>=5.9) if #available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) { return await withDiscardingTaskGroup() { taskGroup in await self.run(in: &taskGroup) @@ -313,7 +313,7 @@ public final class ConnectionPool< case scheduleTimer(StateMachine.Timer) } - #if swift(>=5.8) && os(Linux) || swift(>=5.9) + #if os(Linux) || compiler(>=5.9) @available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) private func run(in taskGroup: inout DiscardingTaskGroup) async { for await event in self.eventStream { @@ -507,7 +507,7 @@ public final class ConnectionPool< await withTaskGroup(of: TimerRunResult.self, returning: Void.self) { taskGroup in taskGroup.addTask { do { - #if swift(>=5.8) && os(Linux) || swift(>=5.9) + #if os(Linux) || compiler(>=5.9) try await self.clock.sleep(for: timer.duration) #else try await self.clock.sleep(until: self.clock.now.advanced(by: timer.duration), tolerance: nil) @@ -593,7 +593,7 @@ protocol TaskGroupProtocol { mutating func addTask_(operation: @escaping @Sendable () async -> Void) } -#if swift(>=5.8) && os(Linux) || swift(>=5.9) +#if os(Linux) || swift(>=5.9) @available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) extension DiscardingTaskGroup: TaskGroupProtocol { @inlinable diff --git a/Sources/PostgresNIO/New/NotificationListener.swift b/Sources/PostgresNIO/New/NotificationListener.swift index 9e47ff34..4982b8ad 100644 --- a/Sources/PostgresNIO/New/NotificationListener.swift +++ b/Sources/PostgresNIO/New/NotificationListener.swift @@ -142,7 +142,7 @@ final class NotificationListener: @unchecked Sendable { } -#if swift(<5.9) +#if compiler(<5.9) // Async stream API backfill extension AsyncThrowingStream { static func makeStream( diff --git a/Sources/PostgresNIO/New/PostgresNotificationSequence.swift b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift index 55fb0670..735c01b0 100644 --- a/Sources/PostgresNIO/New/PostgresNotificationSequence.swift +++ b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift @@ -3,7 +3,7 @@ public struct PostgresNotification: Sendable { public let payload: String } -public struct PostgresNotificationSequence: AsyncSequence { +public struct PostgresNotificationSequence: AsyncSequence, Sendable { public typealias Element = PostgresNotification let base: AsyncThrowingStream @@ -20,8 +20,3 @@ public struct PostgresNotificationSequence: AsyncSequence { } } } - -#if swift(>=5.7) -// AsyncThrowingStream is marked as Sendable in Swift 5.6 -extension PostgresNotificationSequence: Sendable {} -#endif diff --git a/Sources/PostgresNIO/Utilities/Exports.swift b/Sources/PostgresNIO/Utilities/Exports.swift index 58e12891..144ff3c9 100644 --- a/Sources/PostgresNIO/Utilities/Exports.swift +++ b/Sources/PostgresNIO/Utilities/Exports.swift @@ -1,14 +1,3 @@ -#if swift(>=5.8) - @_documentation(visibility: internal) @_exported import NIO @_documentation(visibility: internal) @_exported import NIOSSL @_documentation(visibility: internal) @_exported import struct Logging.Logger - -#else - -// TODO: Remove this with the next major release! -@_exported import NIO -@_exported import NIOSSL -@_exported import struct Logging.Logger - -#endif From 6c3d0a938d248965da42d451f619cf74f0fff882 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 14 Jun 2024 11:53:39 +0200 Subject: [PATCH 244/292] Update ServiceLifecycle to 2.5.0 (#484) --- Package.swift | 2 +- Sources/PostgresNIO/Pool/PostgresClient.swift | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Package.swift b/Package.swift index 79c740f9..d24ee979 100644 --- a/Package.swift +++ b/Package.swift @@ -22,7 +22,7 @@ let package = Package( .package(url: "https://github.com/apple/swift-crypto.git", "2.0.0" ..< "4.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.4.1"), .package(url: "https://github.com/apple/swift-log.git", from: "1.5.3"), - .package(url: "https://github.com/swift-server/swift-service-lifecycle.git", from: "2.4.1"), + .package(url: "https://github.com/swift-server/swift-service-lifecycle.git", from: "2.5.0"), ], targets: [ .target( diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 2116a51d..2e1b7e11 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -419,7 +419,7 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { let atomicOp = self.runningAtomic.compareExchange(expected: false, desired: true, ordering: .relaxed) precondition(!atomicOp.original, "PostgresClient.run() should just be called once!") - await cancelOnGracefulShutdown { + await cancelWhenGracefulShutdown { await self.pool.run() } } From 7b621c16f6a0a8a0af8badd56b6f980457a1b7c5 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 14 Jun 2024 13:32:56 +0200 Subject: [PATCH 245/292] Enable StrictConcurrency checking (#483) --- Package.swift | 19 ++++-- .../ConnectionPoolModule/ConnectionPool.swift | 4 +- .../ConnectionPoolObservabilityDelegate.swift | 2 +- .../Message/PostgresMessage+Identifier.swift | 2 +- Sources/PostgresNIO/Pool/PostgresClient.swift | 2 +- .../Utilities/PostgresError+Code.swift | 2 +- .../Mocks/MockConnectionFactory.swift | 2 +- Tests/IntegrationTests/PostgresNIOTests.swift | 61 ++++++++++--------- .../New/PostgresConnectionTests.swift | 19 +++--- 9 files changed, 63 insertions(+), 50 deletions(-) diff --git a/Package.swift b/Package.swift index d24ee979..0683dbe9 100644 --- a/Package.swift +++ b/Package.swift @@ -1,6 +1,10 @@ // swift-tools-version:5.8 import PackageDescription +let swiftSettings: [SwiftSetting] = [ + .enableUpcomingFeature("StrictConcurrency") +] + let package = Package( name: "postgres-nio", platforms: [ @@ -41,7 +45,8 @@ let package = Package( .product(name: "NIOSSL", package: "swift-nio-ssl"), .product(name: "NIOFoundationCompat", package: "swift-nio"), .product(name: "ServiceLifecycle", package: "swift-service-lifecycle"), - ] + ], + swiftSettings: swiftSettings ), .target( name: "_ConnectionPoolModule", @@ -49,7 +54,8 @@ let package = Package( .product(name: "Atomics", package: "swift-atomics"), .product(name: "DequeModule", package: "swift-collections"), ], - path: "Sources/ConnectionPoolModule" + path: "Sources/ConnectionPoolModule", + swiftSettings: swiftSettings ), .testTarget( name: "PostgresNIOTests", @@ -57,7 +63,8 @@ let package = Package( .target(name: "PostgresNIO"), .product(name: "NIOEmbedded", package: "swift-nio"), .product(name: "NIOTestUtils", package: "swift-nio"), - ] + ], + swiftSettings: swiftSettings ), .testTarget( name: "ConnectionPoolModuleTests", @@ -67,14 +74,16 @@ let package = Package( .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), .product(name: "NIOEmbedded", package: "swift-nio"), - ] + ], + swiftSettings: swiftSettings ), .testTarget( name: "IntegrationTests", dependencies: [ .target(name: "PostgresNIO"), .product(name: "NIOTestUtils", package: "swift-nio"), - ] + ], + swiftSettings: swiftSettings ), ] ) diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index 3231cc06..03c269ee 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -1,6 +1,6 @@ @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -public struct ConnectionAndMetadata { +public struct ConnectionAndMetadata: Sendable { public var connection: Connection @@ -495,7 +495,7 @@ public final class ConnectionPool< } @usableFromInline - enum TimerRunResult { + enum TimerRunResult: Sendable { case timerTriggered case timerCancelled case cancellationContinuationFinished diff --git a/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift b/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift index 35f30dcb..fc1e300c 100644 --- a/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift +++ b/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift @@ -37,7 +37,7 @@ public protocol ConnectionPoolObservabilityDelegate: Sendable { func requestQueueDepthChanged(_ newDepth: Int) } -public struct NoOpConnectionPoolMetrics: ConnectionPoolObservabilityDelegate { +public struct NoOpConnectionPoolMetrics: ConnectionPoolObservabilityDelegate { public init(connectionIDType: ConnectionID.Type) {} public func startedConnecting(id: ConnectionID) {} diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift b/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift index 786b91ef..5d111e3b 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift @@ -4,7 +4,7 @@ extension PostgresMessage { /// Identifies an incoming or outgoing postgres message. Sent as the first byte, before the message size. /// Values are not unique across all identifiers, meaning some messages will require keeping state to identify. @available(*, deprecated, message: "Will be removed from public API.") - public struct Identifier: ExpressibleByIntegerLiteral, Equatable, CustomStringConvertible { + public struct Identifier: Sendable, ExpressibleByIntegerLiteral, Equatable, CustomStringConvertible { // special public static let none: Identifier = 0x00 // special diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 2e1b7e11..0907f1f8 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -478,7 +478,7 @@ extension PostgresConnection: PooledConnection { self.channel.close(mode: .all, promise: nil) } - public func onClose(_ closure: @escaping ((any Error)?) -> ()) { + public func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) { self.closeFuture.whenComplete { _ in closure(nil) } } } diff --git a/Sources/PostgresNIO/Utilities/PostgresError+Code.swift b/Sources/PostgresNIO/Utilities/PostgresError+Code.swift index 11224f4b..fae903fe 100644 --- a/Sources/PostgresNIO/Utilities/PostgresError+Code.swift +++ b/Sources/PostgresNIO/Utilities/PostgresError+Code.swift @@ -1,5 +1,5 @@ extension PostgresError { - public struct Code: ExpressibleByStringLiteral, Equatable { + public struct Code: Sendable, ExpressibleByStringLiteral, Equatable { // Class 00 — Successful Completion public static let successfulCompletion: Code = "00000" diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift index eec2e7c3..1c9bfff8 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift +++ b/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift @@ -2,7 +2,7 @@ import DequeModule @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class MockConnectionFactory where Clock.Duration == Duration { +final class MockConnectionFactory: Sendable where Clock.Duration == Duration { typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator typealias Request = ConnectionRequest typealias KeepAliveBehavior = MockPingPongBehavior diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index 4d06c13e..ff59209b 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -1,5 +1,6 @@ import Logging @testable import PostgresNIO +import Atomics import XCTest import NIOCore import NIOPosix @@ -112,59 +113,59 @@ final class PostgresNIOTests: XCTestCase { XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var receivedNotifications: [PostgresMessage.NotificationResponse] = [] + let receivedNotifications = ManagedAtomic(0) conn?.addListener(channel: "example") { context, notification in - receivedNotifications.append(notification) + receivedNotifications.wrappingIncrement(ordering: .relaxed) + XCTAssertEqual(notification.channel, "example") + XCTAssertEqual(notification.payload, "") } XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) // Notifications are asynchronous, so we should run at least one more query to make sure we'll have received the notification response by then XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) - XCTAssertEqual(receivedNotifications.count, 1) - XCTAssertEqual(receivedNotifications.first?.channel, "example") - XCTAssertEqual(receivedNotifications.first?.payload, "") + XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1) } func testNotificationsNonEmptyPayload() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var receivedNotifications: [PostgresMessage.NotificationResponse] = [] + let receivedNotifications = ManagedAtomic(0) conn?.addListener(channel: "example") { context, notification in - receivedNotifications.append(notification) + receivedNotifications.wrappingIncrement(ordering: .relaxed) + XCTAssertEqual(notification.channel, "example") + XCTAssertEqual(notification.payload, "Notification payload example") } XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example, 'Notification payload example'").wait()) // Notifications are asynchronous, so we should run at least one more query to make sure we'll have received the notification response by then XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) - XCTAssertEqual(receivedNotifications.count, 1) - XCTAssertEqual(receivedNotifications.first?.channel, "example") - XCTAssertEqual(receivedNotifications.first?.payload, "Notification payload example") + XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1) } func testNotificationsRemoveHandlerWithinHandler() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var receivedNotifications = 0 + let receivedNotifications = ManagedAtomic(0) conn?.addListener(channel: "example") { context, notification in - receivedNotifications += 1 + receivedNotifications.wrappingIncrement(ordering: .relaxed) context.stop() } XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) - XCTAssertEqual(receivedNotifications, 1) + XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1) } func testNotificationsRemoveHandlerOutsideHandler() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var receivedNotifications = 0 + let receivedNotifications = ManagedAtomic(0) let context = conn?.addListener(channel: "example") { context, notification in - receivedNotifications += 1 + receivedNotifications.wrappingIncrement(ordering: .relaxed) } XCTAssertNotNil(context) XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) @@ -173,47 +174,47 @@ final class PostgresNIOTests: XCTestCase { context?.stop() XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) - XCTAssertEqual(receivedNotifications, 1) + XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1) } func testNotificationsMultipleRegisteredHandlers() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var receivedNotifications1 = 0 + let receivedNotifications1 = ManagedAtomic(0) conn?.addListener(channel: "example") { context, notification in - receivedNotifications1 += 1 + receivedNotifications1.wrappingIncrement(ordering: .relaxed) } - var receivedNotifications2 = 0 + let receivedNotifications2 = ManagedAtomic(0) conn?.addListener(channel: "example") { context, notification in - receivedNotifications2 += 1 + receivedNotifications2.wrappingIncrement(ordering: .relaxed) } XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) - XCTAssertEqual(receivedNotifications1, 1) - XCTAssertEqual(receivedNotifications2, 1) + XCTAssertEqual(receivedNotifications1.load(ordering: .relaxed), 1) + XCTAssertEqual(receivedNotifications2.load(ordering: .relaxed), 1) } func testNotificationsMultipleRegisteredHandlersRemoval() throws { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var receivedNotifications1 = 0 + let receivedNotifications1 = ManagedAtomic(0) XCTAssertNotNil(conn?.addListener(channel: "example") { context, notification in - receivedNotifications1 += 1 + receivedNotifications1.wrappingIncrement(ordering: .relaxed) context.stop() }) - var receivedNotifications2 = 0 + let receivedNotifications2 = ManagedAtomic(0) XCTAssertNotNil(conn?.addListener(channel: "example") { context, notification in - receivedNotifications2 += 1 + receivedNotifications2.wrappingIncrement(ordering: .relaxed) }) XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) - XCTAssertEqual(receivedNotifications1, 1) - XCTAssertEqual(receivedNotifications2, 2) + XCTAssertEqual(receivedNotifications1.load(ordering: .relaxed), 1) + XCTAssertEqual(receivedNotifications2.load(ordering: .relaxed), 2) } func testNotificationHandlerFiltersOnChannel() { @@ -1283,11 +1284,11 @@ final class PostgresNIOTests: XCTestCase { XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } var queries: [[PostgresRow]]? - XCTAssertNoThrow(queries = try conn?.prepare(query: "SELECT $1::text as foo;", handler: { query in + XCTAssertNoThrow(queries = try conn?.prepare(query: "SELECT $1::text as foo;", handler: { [eventLoop] query in let a = query.execute(["a"]) let b = query.execute(["b"]) let c = query.execute(["c"]) - return EventLoopFuture.whenAllSucceed([a, b, c], on: self.eventLoop) + return EventLoopFuture.whenAllSucceed([a, b, c], on: eventLoop) }).wait()) XCTAssertEqual(queries?.count, 3) var resultIterator = queries?.makeIterator() diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 209522dd..0bc61efd 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -187,7 +187,7 @@ class PostgresConnectionTests: XCTestCase { func testSimpleListenConnectionDrops() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - try await withThrowingTaskGroup(of: Void.self) { taskGroup in + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup in taskGroup.addTask { let events = try await connection.listen("foo") var iterator = events.makeAsyncIterator() @@ -197,7 +197,7 @@ class PostgresConnectionTests: XCTestCase { _ = try await iterator.next() XCTFail("Did not expect to not throw") } catch { - self.logger.error("error", metadata: ["error": "\(error)"]) + logger.error("error", metadata: ["error": "\(error)"]) } } @@ -226,10 +226,10 @@ class PostgresConnectionTests: XCTestCase { func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in for _ in 1...2 { taskGroup.addTask { - let rows = try await connection.query("SELECT 1;", logger: self.logger) + let rows = try await connection.query("SELECT 1;", logger: logger) var iterator = rows.decode(Int.self).makeAsyncIterator() let first = try await iterator.next() XCTAssertEqual(first, 1) @@ -286,10 +286,10 @@ class PostgresConnectionTests: XCTestCase { func testCloseClosesImmediatly() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in for _ in 1...2 { taskGroup.addTask { - try await connection.query("SELECT 1;", logger: self.logger) + try await connection.query("SELECT 1;", logger: logger) } } @@ -319,8 +319,9 @@ class PostgresConnectionTests: XCTestCase { func testIfServerJustClosesTheErrorReflectsThat() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + let logger = self.logger - async let response = try await connection.query("SELECT 1;", logger: self.logger) + async let response = try await connection.query("SELECT 1;", logger: logger) let listenMessage = try await channel.waitForUnpreparedRequest() XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") @@ -423,6 +424,7 @@ class PostgresConnectionTests: XCTestCase { case pleaseDontCrash } channel.pipeline.fireUserInboundEventTriggered(MyEvent.pleaseDontCrash) + try await connection.close() } func testSerialExecutionOfSamePreparedStatement() async throws { @@ -651,7 +653,8 @@ class PostgresConnectionTests: XCTestCase { database: "database" ) - async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: self.logger) + let logger = self.logger + async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: logger) let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false)))) try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) From f55caa7745a43357f7af7dfdd0300955dbd8c6a3 Mon Sep 17 00:00:00 2001 From: Mahdi Bahrami Date: Mon, 24 Jun 2024 15:15:22 +0330 Subject: [PATCH 246/292] [Fix] Query Hangs if Connection is Closed (#487) --- .../Connection/PostgresConnection.swift | 39 ++-- .../PSQLIntegrationTests.swift | 1 - .../New/PostgresConnectionTests.swift | 169 ++++++++++++++++++ 3 files changed, 197 insertions(+), 12 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index eb9dc791..a6efcfdf 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -222,7 +222,7 @@ public final class PostgresConnection: @unchecked Sendable { promise: promise ) - self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + self.write(.extendedQuery(context), cascadingFailureTo: promise) return promise.futureResult } @@ -239,7 +239,8 @@ public final class PostgresConnection: @unchecked Sendable { promise: promise ) - self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + self.write(.extendedQuery(context), cascadingFailureTo: promise) + return promise.futureResult.map { rowDescription in PSQLPreparedStatement(name: name, query: query, connection: self, rowDescription: rowDescription) } @@ -255,7 +256,8 @@ public final class PostgresConnection: @unchecked Sendable { logger: logger, promise: promise) - self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + self.write(.extendedQuery(context), cascadingFailureTo: promise) + return promise.futureResult } @@ -263,7 +265,8 @@ public final class PostgresConnection: @unchecked Sendable { let promise = self.channel.eventLoop.makePromise(of: Void.self) let context = CloseCommandContext(target: target, logger: logger, promise: promise) - self.channel.write(HandlerTask.closeCommand(context), promise: nil) + self.write(.closeCommand(context), cascadingFailureTo: promise) + return promise.futureResult } @@ -426,7 +429,7 @@ extension PostgresConnection { promise: promise ) - self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + self.write(.extendedQuery(context), cascadingFailureTo: promise) do { return try await promise.futureResult.map({ $0.asyncSequence() }).get() @@ -455,7 +458,11 @@ extension PostgresConnection { let task = HandlerTask.startListening(listener) - self.channel.write(task, promise: nil) + let writePromise = self.channel.eventLoop.makePromise(of: Void.self) + self.channel.write(task, promise: writePromise) + writePromise.futureResult.whenFailure { error in + listener.failed(error) + } } } onCancel: { let task = HandlerTask.cancelListening(channel, id) @@ -480,7 +487,9 @@ extension PostgresConnection { logger: logger, promise: promise )) - self.channel.write(task, promise: nil) + + self.write(task, cascadingFailureTo: promise) + do { return try await promise.futureResult .map { $0.asyncSequence() } @@ -515,7 +524,9 @@ extension PostgresConnection { logger: logger, promise: promise )) - self.channel.write(task, promise: nil) + + self.write(task, cascadingFailureTo: promise) + do { return try await promise.futureResult .map { $0.commandTag } @@ -530,6 +541,12 @@ extension PostgresConnection { throw error // rethrow with more metadata } } + + private func write(_ task: HandlerTask, cascadingFailureTo promise: EventLoopPromise) { + let writePromise = self.channel.eventLoop.makePromise(of: Void.self) + self.channel.write(task, promise: writePromise) + writePromise.futureResult.cascadeFailure(to: promise) + } } // MARK: EventLoopFuture interface @@ -674,7 +691,7 @@ internal enum PostgresCommands: PostgresRequest { /// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support. public final class PostgresListenContext: Sendable { - private let promise: EventLoopPromise + let promise: EventLoopPromise var future: EventLoopFuture { self.promise.futureResult @@ -713,8 +730,7 @@ extension PostgresConnection { closure: notificationHandler ) - let task = HandlerTask.startListening(listener) - self.channel.write(task, promise: nil) + self.write(.startListening(listener), cascadingFailureTo: listenContext.promise) listenContext.future.whenComplete { _ in let task = HandlerTask.cancelListening(channel, id) @@ -761,3 +777,4 @@ extension PostgresConnection { #endif } } + diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 57939c06..913d91b2 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -359,5 +359,4 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(obj?.bar, 2) } } - } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 0bc61efd..5c7d4c83 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -224,6 +224,63 @@ class PostgresConnectionTests: XCTestCase { } } + func testSimpleListenFailsIfConnectionIsClosed() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await connection.closeGracefully() + + XCTAssertEqual(channel.isActive, false) + + do { + _ = try await connection.listen("test_channel") + XCTFail("Expected to fail") + } catch let error as ChannelError { + XCTAssertEqual(error, .ioOnClosedChannel) + } + } + + func testSimpleListenFailsIfConnectionIsClosedWhileListening() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + let events = try await connection.listen("foo") + var iterator = events.makeAsyncIterator() + let first = try await iterator.next() + XCTAssertEqual(first?.payload, "wooohooo") + do { + _ = try await iterator.next() + XCTFail("Did not expect to not throw") + } catch let error as PSQLError { + XCTAssertEqual(error.code, .clientClosedConnection) + } + } + + let listenMessage = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) + + try await connection.close() + + XCTAssertEqual(channel.isActive, false) + + switch await taskGroup.nextResult()! { + case .success: + break + case .failure(let failure): + XCTFail("Unexpected error: \(failure)") + } + } + } + func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in @@ -638,6 +695,118 @@ class PostgresConnectionTests: XCTestCase { } } + func testQueryFailsIfConnectionIsClosed() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await connection.closeGracefully() + + XCTAssertEqual(channel.isActive, false) + + do { + _ = try await connection.query("SELECT version;", logger: self.logger) + XCTFail("Expected to fail") + } catch let error as ChannelError { + XCTAssertEqual(error, .ioOnClosedChannel) + } + } + + func testPrepareStatementFailsIfConnectionIsClosed() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await connection.closeGracefully() + + XCTAssertEqual(channel.isActive, false) + + do { + _ = try await connection.prepareStatement("SELECT version;", with: "test_query", logger: .psqlTest).get() + XCTFail("Expected to fail") + } catch let error as ChannelError { + XCTAssertEqual(error, .ioOnClosedChannel) + } + } + + func testExecuteFailsIfConnectionIsClosed() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await connection.closeGracefully() + + XCTAssertEqual(channel.isActive, false) + + do { + let statement = PSQLExecuteStatement(name: "SELECT version;", binds: .init(), rowDescription: nil) + _ = try await connection.execute(statement, logger: .psqlTest).get() + XCTFail("Expected to fail") + } catch let error as ChannelError { + XCTAssertEqual(error, .ioOnClosedChannel) + } + } + + func testExecutePreparedStatementFailsIfConnectionIsClosed() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await connection.closeGracefully() + + XCTAssertEqual(channel.isActive, false) + + struct TestPreparedStatement: PostgresPreparedStatement { + static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" + typealias Row = (Int, String) + + var state: String + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.state) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode(Row.self) + } + } + + do { + let preparedStatement = TestPreparedStatement(state: "active") + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + XCTFail("Expected to fail") + } catch let error as ChannelError { + XCTAssertEqual(error, .ioOnClosedChannel) + } + } + + func testExecutePreparedStatementWithVoidRowFailsIfConnectionIsClosed() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await connection.closeGracefully() + + XCTAssertEqual(channel.isActive, false) + + struct TestPreparedStatement: PostgresPreparedStatement { + static let sql = "SELECT * FROM pg_stat_activity WHERE state = $1" + typealias Row = () + + var state: String + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.state) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + () + } + } + + do { + let preparedStatement = TestPreparedStatement(state: "active") + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + XCTFail("Expected to fail") + } catch let error as ChannelError { + XCTAssertEqual(error, .ioOnClosedChannel) + } + } + func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() let channel = await NIOAsyncTestingChannel(handlers: [ From 200a94a13381f2cbc2c4f5303da777997a80937d Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 27 Jun 2024 17:59:03 +0200 Subject: [PATCH 247/292] Explicitly mark the AsyncSequence iterators as non Sendable (#490) --- Package.swift | 2 +- Sources/PostgresNIO/New/PostgresNotificationSequence.swift | 3 +++ Sources/PostgresNIO/New/PostgresQuery.swift | 2 +- Sources/PostgresNIO/New/PostgresRowSequence.swift | 5 ++++- 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/Package.swift b/Package.swift index 0683dbe9..5c83eded 100644 --- a/Package.swift +++ b/Package.swift @@ -2,7 +2,7 @@ import PackageDescription let swiftSettings: [SwiftSetting] = [ - .enableUpcomingFeature("StrictConcurrency") + .enableUpcomingFeature("StrictConcurrency"), ] let package = Package( diff --git a/Sources/PostgresNIO/New/PostgresNotificationSequence.swift b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift index 735c01b0..d8f525eb 100644 --- a/Sources/PostgresNIO/New/PostgresNotificationSequence.swift +++ b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift @@ -20,3 +20,6 @@ public struct PostgresNotificationSequence: AsyncSequence, Sendable { } } } + +@available(*, unavailable) +extension PostgresNotificationSequence.AsyncIterator: Sendable {} diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 1cfcf2dc..b695dcfe 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -26,7 +26,7 @@ extension PostgresQuery: ExpressibleByStringInterpolation { } extension PostgresQuery { - public struct StringInterpolation: StringInterpolationProtocol { + public struct StringInterpolation: StringInterpolationProtocol, Sendable { public typealias StringLiteralType = String @usableFromInline diff --git a/Sources/PostgresNIO/New/PostgresRowSequence.swift b/Sources/PostgresNIO/New/PostgresRowSequence.swift index ccf4f69c..3936b51e 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence.swift @@ -4,7 +4,7 @@ import NIOConcurrencyHelpers /// An async sequence of ``PostgresRow``s. /// /// - Note: This is a struct to allow us to move to a move only type easily once they become available. -public struct PostgresRowSequence: AsyncSequence { +public struct PostgresRowSequence: AsyncSequence, Sendable { public typealias Element = PostgresRow typealias BackingSequence = NIOThrowingAsyncSequenceProducer @@ -56,6 +56,9 @@ extension PostgresRowSequence { } } +@available(*, unavailable) +extension PostgresRowSequence.AsyncIterator: Sendable {} + extension PostgresRowSequence { public func collect() async throws -> [PostgresRow] { var result = [PostgresRow]() From d18b137640222fe29a22568077c4799d213fdf96 Mon Sep 17 00:00:00 2001 From: George Barnett Date: Thu, 25 Jul 2024 09:56:51 +0100 Subject: [PATCH 248/292] Change 'unsafeDowncast' to 'as!' (#495) Motivation: The 'unsafeDowncast' can cause a miscompile leading to unexpected runtime behaviour. Modifications: - Use 'as!' instead Result: No miscompiles on 5.10 --- Sources/ConnectionPoolModule/NIOLock.swift | 29 +++++++++++----------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/Sources/ConnectionPoolModule/NIOLock.swift b/Sources/ConnectionPoolModule/NIOLock.swift index dbc7dbe9..13a9df4a 100644 --- a/Sources/ConnectionPoolModule/NIOLock.swift +++ b/Sources/ConnectionPoolModule/NIOLock.swift @@ -52,12 +52,12 @@ extension LockOperations { debugOnly { pthread_mutexattr_settype(&attr, .init(PTHREAD_MUTEX_ERRORCHECK)) } - + let err = pthread_mutex_init(mutex, &attr) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") #endif } - + @inlinable static func destroy(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() @@ -69,7 +69,7 @@ extension LockOperations { precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") #endif } - + @inlinable static func lock(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() @@ -81,7 +81,7 @@ extension LockOperations { precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") #endif } - + @inlinable static func unlock(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() @@ -125,49 +125,50 @@ extension LockOperations { // See also: https://github.com/apple/swift/pull/40000 @usableFromInline final class LockStorage: ManagedBuffer { - + @inlinable static func create(value: Value) -> Self { let buffer = Self.create(minimumCapacity: 1) { _ in return value } - let storage = unsafeDowncast(buffer, to: Self.self) - + // Avoid 'unsafeDowncast' as there is a miscompilation on 5.10. + let storage = buffer as! Self + storage.withUnsafeMutablePointers { _, lockPtr in LockOperations.create(lockPtr) } - + return storage } - + @inlinable func lock() { self.withUnsafeMutablePointerToElements { lockPtr in LockOperations.lock(lockPtr) } } - + @inlinable func unlock() { self.withUnsafeMutablePointerToElements { lockPtr in LockOperations.unlock(lockPtr) } } - + @inlinable deinit { self.withUnsafeMutablePointerToElements { lockPtr in LockOperations.destroy(lockPtr) } } - + @inlinable func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { try self.withUnsafeMutablePointerToElements { lockPtr in return try body(lockPtr) } } - + @inlinable func withLockedValue(_ mutate: (inout Value) throws -> T) rethrows -> T { try self.withUnsafeMutablePointers { valuePtr, lockPtr in @@ -192,7 +193,7 @@ extension LockStorage: @unchecked Sendable { } struct NIOLock { @usableFromInline internal let _storage: LockStorage - + /// Create a new lock. @inlinable init() { From cd5318a01a1efcb1e0b3c82a0ce5c9fefaf1cb2d Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 20 Aug 2024 14:04:28 +0200 Subject: [PATCH 249/292] Revert "[Fix] Query Hangs if Connection is Closed (#487)" (#501) This reverts commit f55caa7745a43357f7af7dfdd0300955dbd8c6a3. --- .../Connection/PostgresConnection.swift | 39 ++-- .../PSQLIntegrationTests.swift | 1 + .../New/PostgresConnectionTests.swift | 169 ------------------ 3 files changed, 12 insertions(+), 197 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index a6efcfdf..eb9dc791 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -222,7 +222,7 @@ public final class PostgresConnection: @unchecked Sendable { promise: promise ) - self.write(.extendedQuery(context), cascadingFailureTo: promise) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) return promise.futureResult } @@ -239,8 +239,7 @@ public final class PostgresConnection: @unchecked Sendable { promise: promise ) - self.write(.extendedQuery(context), cascadingFailureTo: promise) - + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) return promise.futureResult.map { rowDescription in PSQLPreparedStatement(name: name, query: query, connection: self, rowDescription: rowDescription) } @@ -256,8 +255,7 @@ public final class PostgresConnection: @unchecked Sendable { logger: logger, promise: promise) - self.write(.extendedQuery(context), cascadingFailureTo: promise) - + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) return promise.futureResult } @@ -265,8 +263,7 @@ public final class PostgresConnection: @unchecked Sendable { let promise = self.channel.eventLoop.makePromise(of: Void.self) let context = CloseCommandContext(target: target, logger: logger, promise: promise) - self.write(.closeCommand(context), cascadingFailureTo: promise) - + self.channel.write(HandlerTask.closeCommand(context), promise: nil) return promise.futureResult } @@ -429,7 +426,7 @@ extension PostgresConnection { promise: promise ) - self.write(.extendedQuery(context), cascadingFailureTo: promise) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) do { return try await promise.futureResult.map({ $0.asyncSequence() }).get() @@ -458,11 +455,7 @@ extension PostgresConnection { let task = HandlerTask.startListening(listener) - let writePromise = self.channel.eventLoop.makePromise(of: Void.self) - self.channel.write(task, promise: writePromise) - writePromise.futureResult.whenFailure { error in - listener.failed(error) - } + self.channel.write(task, promise: nil) } } onCancel: { let task = HandlerTask.cancelListening(channel, id) @@ -487,9 +480,7 @@ extension PostgresConnection { logger: logger, promise: promise )) - - self.write(task, cascadingFailureTo: promise) - + self.channel.write(task, promise: nil) do { return try await promise.futureResult .map { $0.asyncSequence() } @@ -524,9 +515,7 @@ extension PostgresConnection { logger: logger, promise: promise )) - - self.write(task, cascadingFailureTo: promise) - + self.channel.write(task, promise: nil) do { return try await promise.futureResult .map { $0.commandTag } @@ -541,12 +530,6 @@ extension PostgresConnection { throw error // rethrow with more metadata } } - - private func write(_ task: HandlerTask, cascadingFailureTo promise: EventLoopPromise) { - let writePromise = self.channel.eventLoop.makePromise(of: Void.self) - self.channel.write(task, promise: writePromise) - writePromise.futureResult.cascadeFailure(to: promise) - } } // MARK: EventLoopFuture interface @@ -691,7 +674,7 @@ internal enum PostgresCommands: PostgresRequest { /// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support. public final class PostgresListenContext: Sendable { - let promise: EventLoopPromise + private let promise: EventLoopPromise var future: EventLoopFuture { self.promise.futureResult @@ -730,7 +713,8 @@ extension PostgresConnection { closure: notificationHandler ) - self.write(.startListening(listener), cascadingFailureTo: listenContext.promise) + let task = HandlerTask.startListening(listener) + self.channel.write(task, promise: nil) listenContext.future.whenComplete { _ in let task = HandlerTask.cancelListening(channel, id) @@ -777,4 +761,3 @@ extension PostgresConnection { #endif } } - diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 913d91b2..57939c06 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -359,4 +359,5 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(obj?.bar, 2) } } + } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 5c7d4c83..0bc61efd 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -224,63 +224,6 @@ class PostgresConnectionTests: XCTestCase { } } - func testSimpleListenFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - do { - _ = try await connection.listen("test_channel") - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - - func testSimpleListenFailsIfConnectionIsClosedWhileListening() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { - let events = try await connection.listen("foo") - var iterator = events.makeAsyncIterator() - let first = try await iterator.next() - XCTAssertEqual(first?.payload, "wooohooo") - do { - _ = try await iterator.next() - XCTFail("Did not expect to not throw") - } catch let error as PSQLError { - XCTAssertEqual(error.code, .clientClosedConnection) - } - } - - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - - try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) - - try await connection.close() - - XCTAssertEqual(channel.isActive, false) - - switch await taskGroup.nextResult()! { - case .success: - break - case .failure(let failure): - XCTFail("Unexpected error: \(failure)") - } - } - } - func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in @@ -695,118 +638,6 @@ class PostgresConnectionTests: XCTestCase { } } - func testQueryFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - do { - _ = try await connection.query("SELECT version;", logger: self.logger) - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - - func testPrepareStatementFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - do { - _ = try await connection.prepareStatement("SELECT version;", with: "test_query", logger: .psqlTest).get() - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - - func testExecuteFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - do { - let statement = PSQLExecuteStatement(name: "SELECT version;", binds: .init(), rowDescription: nil) - _ = try await connection.execute(statement, logger: .psqlTest).get() - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - - func testExecutePreparedStatementFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - struct TestPreparedStatement: PostgresPreparedStatement { - static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" - typealias Row = (Int, String) - - var state: String - - func makeBindings() -> PostgresBindings { - var bindings = PostgresBindings() - bindings.append(self.state) - return bindings - } - - func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { - try row.decode(Row.self) - } - } - - do { - let preparedStatement = TestPreparedStatement(state: "active") - _ = try await connection.execute(preparedStatement, logger: .psqlTest) - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - - func testExecutePreparedStatementWithVoidRowFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - struct TestPreparedStatement: PostgresPreparedStatement { - static let sql = "SELECT * FROM pg_stat_activity WHERE state = $1" - typealias Row = () - - var state: String - - func makeBindings() -> PostgresBindings { - var bindings = PostgresBindings() - bindings.append(self.state) - return bindings - } - - func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { - () - } - } - - do { - let preparedStatement = TestPreparedStatement(state: "active") - _ = try await connection.execute(preparedStatement, logger: .psqlTest) - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() let channel = await NIOAsyncTestingChannel(handlers: [ From 3de37e6438d018159a9c3ef1ea0ca154039ce480 Mon Sep 17 00:00:00 2001 From: Mahdi Bahrami Date: Wed, 21 Aug 2024 16:15:31 +0330 Subject: [PATCH 250/292] Handle `EmptyQueryResponse` (#500) --- .../ExtendedQueryStateMachine.swift | 36 ++++++++-- Sources/PostgresNIO/New/PSQLRowStream.swift | 66 +++++++++++-------- .../New/PostgresChannelHandler.swift | 4 +- .../PostgresNIO/PostgresDatabase+Query.swift | 5 +- .../PSQLIntegrationTests.swift | 19 ++++++ .../ExtendedQueryStateMachineTests.swift | 22 ++++++- .../PreparedStatementStateMachineTests.swift | 8 +-- .../New/PSQLRowStreamTests.swift | 2 +- 8 files changed, 114 insertions(+), 48 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 78f0d202..087a6c24 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -10,7 +10,8 @@ struct ExtendedQueryStateMachine { case parameterDescriptionReceived(ExtendedQueryContext) case rowDescriptionReceived(ExtendedQueryContext, [RowDescription.Column]) case noDataMessageReceived(ExtendedQueryContext) - + case emptyQueryResponseReceived + /// A state that is used if a noData message was received before. If a row description was received `bufferingRows` is /// used after receiving a `bindComplete` message case bindCompleteReceived(ExtendedQueryContext) @@ -122,7 +123,7 @@ struct ExtendedQueryStateMachine { return .forwardStreamError(.queryCancelled, read: true) } - case .commandComplete, .error, .drain: + case .commandComplete, .emptyQueryResponseReceived, .error, .drain: // the stream has already finished. return .wait @@ -229,6 +230,7 @@ struct ExtendedQueryStateMachine { .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, + .emptyQueryResponseReceived, .bindCompleteReceived, .streaming, .drain, @@ -268,6 +270,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .bindCompleteReceived, .commandComplete, @@ -285,7 +288,7 @@ struct ExtendedQueryStateMachine { case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return self.avoidingStateMachineCoW { state -> Action in state = .commandComplete(commandTag: commandTag) - let result = QueryResult(value: .noRows(commandTag), logger: context.logger) + let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger) return .succeedQuery(eventLoopPromise, with: result) } @@ -309,6 +312,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .commandComplete, .error: @@ -319,7 +323,22 @@ struct ExtendedQueryStateMachine { } mutating func emptyQueryResponseReceived() -> Action { - preconditionFailure("Unimplemented") + guard case .bindCompleteReceived(let queryContext) = self.state else { + return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) + } + + switch queryContext.query { + case .unnamed(_, let eventLoopPromise), + .executeStatement(_, let eventLoopPromise): + return self.avoidingStateMachineCoW { state -> Action in + state = .emptyQueryResponseReceived + let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger) + return .succeedQuery(eventLoopPromise, with: result) + } + + case .prepareStatement(_, _, _, _): + return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) + } } mutating func errorReceived(_ errorMessage: PostgresBackendMessage.ErrorResponse) -> Action { @@ -336,7 +355,7 @@ struct ExtendedQueryStateMachine { return self.setAndFireError(error) case .streaming, .drain: return self.setAndFireError(error) - case .commandComplete: + case .commandComplete, .emptyQueryResponseReceived: return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) case .error: preconditionFailure(""" @@ -382,6 +401,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .bindCompleteReceived: preconditionFailure("Requested to consume next row without anything going on.") @@ -405,6 +425,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .bindCompleteReceived: return .wait @@ -449,6 +470,7 @@ struct ExtendedQueryStateMachine { } case .initialized, .commandComplete, + .emptyQueryResponseReceived, .drain, .error: // we already have the complete stream received, now we are waiting for a @@ -495,7 +517,7 @@ struct ExtendedQueryStateMachine { return .forwardStreamError(error, read: true) } - case .commandComplete, .error: + case .commandComplete, .emptyQueryResponseReceived, .error: preconditionFailure(""" This state must not be reached. If the query `.isComplete`, the ConnectionStateMachine must not send any further events to the substate machine. @@ -507,7 +529,7 @@ struct ExtendedQueryStateMachine { var isComplete: Bool { switch self.state { - case .commandComplete, .error: + case .commandComplete, .emptyQueryResponseReceived, .error: return true case .noDataMessageReceived(let context), .rowDescriptionReceived(let context, _): diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index b7f2d4fb..ee925d0e 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -3,7 +3,7 @@ import Logging struct QueryResult { enum Value: Equatable { - case noRows(String) + case noRows(PSQLRowStream.StatementSummary) case rowDescription([RowDescription.Column]) } @@ -16,25 +16,30 @@ struct QueryResult { final class PSQLRowStream: @unchecked Sendable { private typealias AsyncSequenceSource = NIOThrowingAsyncSequenceProducer.Source + enum StatementSummary: Equatable { + case tag(String) + case emptyResponse + } + enum Source { case stream([RowDescription.Column], PSQLRowsDataSource) - case noRows(Result) + case noRows(Result) } let eventLoop: EventLoop let logger: Logger - + private enum BufferState { case streaming(buffer: CircularBuffer, dataSource: PSQLRowsDataSource) - case finished(buffer: CircularBuffer, commandTag: String) + case finished(buffer: CircularBuffer, summary: StatementSummary) case failure(Error) } - + private enum DownstreamState { case waitingForConsumer(BufferState) case iteratingRows(onRow: (PostgresRow) throws -> (), EventLoopPromise, PSQLRowsDataSource) case waitingForAll([PostgresRow], EventLoopPromise<[PostgresRow]>, PSQLRowsDataSource) - case consumed(Result) + case consumed(Result) case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource, onFinish: @Sendable () -> ()) } @@ -52,9 +57,9 @@ final class PSQLRowStream: @unchecked Sendable { case .stream(let rowDescription, let dataSource): self.rowDescription = rowDescription bufferState = .streaming(buffer: .init(), dataSource: dataSource) - case .noRows(.success(let commandTag)): + case .noRows(.success(let summary)): self.rowDescription = [] - bufferState = .finished(buffer: .init(), commandTag: commandTag) + bufferState = .finished(buffer: .init(), summary: summary) case .noRows(.failure(let error)): self.rowDescription = [] bufferState = .failure(error) @@ -98,12 +103,12 @@ final class PSQLRowStream: @unchecked Sendable { self.downstreamState = .asyncSequence(source, dataSource, onFinish: onFinish) self.executeActionBasedOnYieldResult(yieldResult, source: dataSource) - case .finished(let buffer, let commandTag): + case .finished(let buffer, let summary): _ = source.yield(contentsOf: buffer) source.finish() onFinish() - self.downstreamState = .consumed(.success(commandTag)) - + self.downstreamState = .consumed(.success(summary)) + case .failure(let error): source.finish(error) self.downstreamState = .consumed(.failure(error)) @@ -190,12 +195,12 @@ final class PSQLRowStream: @unchecked Sendable { dataSource.request(for: self) return promise.futureResult - case .finished(let buffer, let commandTag): + case .finished(let buffer, let summary): let rows = buffer.map { PostgresRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription) } - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(summary)) return self.eventLoop.makeSucceededFuture(rows) case .failure(let error): @@ -247,8 +252,8 @@ final class PSQLRowStream: @unchecked Sendable { } return promise.futureResult - - case .finished(let buffer, let commandTag): + + case .finished(let buffer, let summary): do { for data in buffer { let row = PostgresRow( @@ -259,7 +264,7 @@ final class PSQLRowStream: @unchecked Sendable { try onRow(row) } - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(summary)) return self.eventLoop.makeSucceededVoidFuture() } catch { self.downstreamState = .consumed(.failure(error)) @@ -292,7 +297,7 @@ final class PSQLRowStream: @unchecked Sendable { case .waitingForConsumer(.finished), .waitingForConsumer(.failure): preconditionFailure("How can new rows be received, if an end was already signalled?") - + case .iteratingRows(let onRow, let promise, let dataSource): do { for data in newRows { @@ -347,25 +352,25 @@ final class PSQLRowStream: @unchecked Sendable { private func receiveEnd(_ commandTag: String) { switch self.downstreamState { case .waitingForConsumer(.streaming(buffer: let buffer, _)): - self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, commandTag: commandTag)) - - case .waitingForConsumer(.finished), .waitingForConsumer(.failure): + self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, summary: .tag(commandTag))) + + case .waitingForConsumer(.finished), .waitingForConsumer(.failure), .consumed(.success(.emptyResponse)): preconditionFailure("How can we get another end, if an end was already signalled?") case .iteratingRows(_, let promise, _): - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) promise.succeed(()) case .waitingForAll(let rows, let promise, _): - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) promise.succeed(rows) case .asyncSequence(let source, _, let onFinish): - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) source.finish() onFinish() - case .consumed: + case .consumed(.success(.tag)), .consumed(.failure): break } } @@ -375,7 +380,7 @@ final class PSQLRowStream: @unchecked Sendable { case .waitingForConsumer(.streaming): self.downstreamState = .waitingForConsumer(.failure(error)) - case .waitingForConsumer(.finished), .waitingForConsumer(.failure): + case .waitingForConsumer(.finished), .waitingForConsumer(.failure), .consumed(.success(.emptyResponse)): preconditionFailure("How can we get another end, if an end was already signalled?") case .iteratingRows(_, let promise, _): @@ -391,7 +396,7 @@ final class PSQLRowStream: @unchecked Sendable { consumer.finish(error) onFinish() - case .consumed: + case .consumed(.success(.tag)), .consumed(.failure): break } } @@ -413,10 +418,15 @@ final class PSQLRowStream: @unchecked Sendable { } var commandTag: String { - guard case .consumed(.success(let commandTag)) = self.downstreamState else { + guard case .consumed(.success(let consumed)) = self.downstreamState else { preconditionFailure("commandTag may only be called if all rows have been consumed") } - return commandTag + switch consumed { + case .tag(let tag): + return tag + case .emptyResponse: + return "" + } } } diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index a3190aa7..ee2af0fe 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -550,9 +550,9 @@ final class PostgresChannelHandler: ChannelDuplexHandler { ) self.rowStream = rows - case .noRows(let commandTag): + case .noRows(let summary): rows = PSQLRowStream( - source: .noRows(.success(commandTag)), + source: .noRows(.success(summary)), eventLoop: context.channel.eventLoop, logger: result.logger ) diff --git a/Sources/PostgresNIO/PostgresDatabase+Query.swift b/Sources/PostgresNIO/PostgresDatabase+Query.swift index 01a7e61f..483d5a7b 100644 --- a/Sources/PostgresNIO/PostgresDatabase+Query.swift +++ b/Sources/PostgresNIO/PostgresDatabase+Query.swift @@ -73,10 +73,7 @@ public struct PostgresQueryMetadata: Sendable { init?(string: String) { let parts = string.split(separator: " ") - guard parts.count >= 1 else { - return nil - } - switch parts[0] { + switch parts.first { case "INSERT": // INSERT oid rows guard parts.count == 3 else { diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 57939c06..d541899b 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -123,6 +123,25 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(foo, "hello") } + func testQueryNothing() throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + var _result: PostgresQueryResult? + XCTAssertNoThrow(_result = try conn?.query(""" + -- Some comments + """, logger: .psqlTest).wait()) + + let result = try XCTUnwrap(_result) + XCTAssertEqual(result.rows, []) + XCTAssertEqual(result.metadata.command, "") + } + func testDecodeIntegers() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index 40e32468..ae484acc 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -20,7 +20,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) XCTAssertEqual(state.noDataReceived(), .wait) XCTAssertEqual(state.bindCompleteReceived(), .wait) - XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows("DELETE 1"), logger: logger))) + XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows(.tag("DELETE 1")), logger: logger))) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) } @@ -77,7 +77,25 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.commandCompletedReceived("SELECT 2"), .forwardStreamComplete([row5, row6], commandTag: "SELECT 2")) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) } - + + func testExtendedQueryWithNoQuery() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "-- some comments" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + XCTAssertEqual(state.noDataReceived(), .wait) + XCTAssertEqual(state.bindCompleteReceived(), .wait) + XCTAssertEqual(state.emptyQueryResponseReceived(), .succeedQuery(promise, with: .init(value: .noRows(.emptyResponse), logger: logger))) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + func testReceiveTotallyUnexpectedMessageInQuery() { var state = ConnectionStateMachine.readyForQuery() diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift index f6c1ddf7..e35e93f7 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift @@ -28,7 +28,7 @@ class PreparedStatementStateMachineTests: XCTestCase { XCTAssertEqual(preparationCompleteAction.statements.count, 1) XCTAssertNil(preparationCompleteAction.rowDescription) firstPreparedStatement.promise.succeed(PSQLRowStream( - source: .noRows(.success("tag")), + source: .noRows(.success(.tag("tag"))), eventLoop: eventLoop, logger: .psqlTest )) @@ -46,7 +46,7 @@ class PreparedStatementStateMachineTests: XCTestCase { return } secondPreparedStatement.promise.succeed(PSQLRowStream( - source: .noRows(.success("tag")), + source: .noRows(.success(.tag("tag"))), eventLoop: eventLoop, logger: .psqlTest )) @@ -135,12 +135,12 @@ class PreparedStatementStateMachineTests: XCTestCase { XCTAssertNil(preparationCompleteAction.rowDescription) firstPreparedStatement.promise.succeed(PSQLRowStream( - source: .noRows(.success("tag")), + source: .noRows(.success(.tag("tag"))), eventLoop: eventLoop, logger: .psqlTest )) secondPreparedStatement.promise.succeed(PSQLRowStream( - source: .noRows(.success("tag")), + source: .noRows(.success(.tag("tag"))), eventLoop: eventLoop, logger: .psqlTest )) diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift index 9a1e9e41..65ca26c3 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -12,7 +12,7 @@ final class PSQLRowStreamTests: XCTestCase { func testEmptyStream() { let stream = PSQLRowStream( - source: .noRows(.success("INSERT 0 1")), + source: .noRows(.success(.tag("INSERT 0 1"))), eventLoop: self.eventLoop, logger: self.logger ) From 9f84290f4f7ba3b3edb749d196243fc2df6b82e6 Mon Sep 17 00:00:00 2001 From: Mahdi Bahrami Date: Thu, 22 Aug 2024 22:22:00 +0330 Subject: [PATCH 251/292] Fix Flaky Nightly Tests (#503) --- Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift index 769bde4b..3f406598 100644 --- a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift @@ -14,7 +14,7 @@ class Date_PSQLCodableTests: XCTestCase { var result: Date? XCTAssertNoThrow(result = try Date(from: &buffer, type: .timestamptz, format: .binary, context: .default)) - XCTAssertEqual(value, result) + XCTAssertEqual(value.timeIntervalSince1970, result?.timeIntervalSince1970 ?? 0, accuracy: 0.001) } func testDecodeRandomDate() { From 8f7e9002462c1a625e590e568fe31251a2429c8a Mon Sep 17 00:00:00 2001 From: Lei Nelissen Date: Wed, 25 Sep 2024 16:48:33 +0200 Subject: [PATCH 252/292] Fix cross-compilation to the static Linux SDK (#510) --- Sources/ConnectionPoolModule/PoolStateMachine.swift | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index 3b996033..6e41f730 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -1,7 +1,9 @@ #if canImport(Darwin) import Darwin -#else +#elseif canImport(Glibc) import Glibc +#elseif canImport(Musl) +import Musl #endif @usableFromInline From c13a11a97b9878cdc1366b4adf03c03cea0b6163 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Wed, 9 Oct 2024 03:33:36 -0500 Subject: [PATCH 253/292] Drop Swift 5.8 support and update CI (#515) --- .github/workflows/test.yml | 16 ++++++---------- Package.swift | 2 +- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1761880d..8364e8ae 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,10 +18,9 @@ jobs: fail-fast: false matrix: swift-image: - - swift:5.8-jammy - swift:5.9-jammy - swift:5.10-noble - - swiftlang/swift:nightly-6.0-jammy + - swift:6.0-noble - swiftlang/swift:nightly-main-jammy container: ${{ matrix.swift-image }} runs-on: ubuntu-latest @@ -48,13 +47,13 @@ jobs: fail-fast: false matrix: postgres-image: - - postgres:16 - - postgres:14 + - postgres:17 + - postgres:15 - postgres:12 include: - - postgres-image: postgres:16 + - postgres-image: postgres:17 postgres-auth: scram-sha-256 - - postgres-image: postgres:14 + - postgres-image: postgres:15 postgres-auth: md5 - postgres-image: postgres:12 postgres-auth: trust @@ -134,11 +133,8 @@ jobs: # Only test one auth method on macOS, Linux tests will cover the others - scram-sha-256 xcode-version: - - '~14.3' - '~15' include: - - xcode-version: '~14.3' - macos-version: 'macos-13' - xcode-version: '~15' macos-version: 'macos-14' runs-on: ${{ matrix.macos-version }} @@ -172,7 +168,7 @@ jobs: uses: actions/checkout@v4 - name: Run all tests run: swift test - + api-breakage: if: github.event_name == 'pull_request' runs-on: ubuntu-latest diff --git a/Package.swift b/Package.swift index 5c83eded..5f6562f6 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.8 +// swift-tools-version:5.9 import PackageDescription let swiftSettings: [SwiftSetting] = [ From 225c5c4adaf48e69fec20321187843c75dada65d Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 9 Oct 2024 10:43:37 +0200 Subject: [PATCH 254/292] Remove all code that solely existed to support Swift 5.8 (#516) --- .../ConnectionPoolModule/ConnectionPool.swift | 14 - .../New/NotificationListener.swift | 16 - .../New/PostgresRow-multi-decode.swift | 1175 ----------------- .../PostgresRowSequence-multi-decode.swift | 215 --- .../PostgresNIO/New/VariadicGenerics.swift | 4 +- 5 files changed, 1 insertion(+), 1423 deletions(-) delete mode 100644 Sources/PostgresNIO/New/PostgresRow-multi-decode.swift delete mode 100644 Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index 03c269ee..5cdb980d 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -571,20 +571,6 @@ extension PoolConfiguration { } } -#if swift(<5.9) -// This should be removed once we support Swift 5.9+ only -extension AsyncStream { - static func makeStream( - of elementType: Element.Type = Element.self, - bufferingPolicy limit: Continuation.BufferingPolicy = .unbounded - ) -> (stream: AsyncStream, continuation: AsyncStream.Continuation) { - var continuation: AsyncStream.Continuation! - let stream = AsyncStream(bufferingPolicy: limit) { continuation = $0 } - return (stream: stream, continuation: continuation!) - } -} -#endif - @usableFromInline protocol TaskGroupProtocol { // We need to call this `addTask_` because some Swift versions define this diff --git a/Sources/PostgresNIO/New/NotificationListener.swift b/Sources/PostgresNIO/New/NotificationListener.swift index 4982b8ad..2f784e33 100644 --- a/Sources/PostgresNIO/New/NotificationListener.swift +++ b/Sources/PostgresNIO/New/NotificationListener.swift @@ -140,19 +140,3 @@ final class NotificationListener: @unchecked Sendable { } } } - - -#if compiler(<5.9) -// Async stream API backfill -extension AsyncThrowingStream { - static func makeStream( - of elementType: Element.Type = Element.self, - throwing failureType: Failure.Type = Failure.self, - bufferingPolicy limit: Continuation.BufferingPolicy = .unbounded - ) -> (stream: AsyncThrowingStream, continuation: AsyncThrowingStream.Continuation) where Failure == Error { - var continuation: AsyncThrowingStream.Continuation! - let stream = AsyncThrowingStream(bufferingPolicy: limit) { continuation = $0 } - return (stream: stream, continuation: continuation!) - } - } - #endif diff --git a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift deleted file mode 100644 index 71aa04dc..00000000 --- a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift +++ /dev/null @@ -1,1175 +0,0 @@ -/// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrow-multi-decode.sh - -#if compiler(<5.9) -extension PostgresRow { - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0) { - precondition(self.columns.count >= 1) - let columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - let column = columnIterator.next().unsafelyUnwrapped - let swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0).Type, file: String = #fileID, line: Int = #line) throws -> (T0) { - try self.decode(T0.self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1) { - precondition(self.columns.count >= 2) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1) { - try self.decode((T0, T1).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2) { - precondition(self.columns.count >= 3) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2) { - try self.decode((T0, T1, T2).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3) { - precondition(self.columns.count >= 4) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3) { - try self.decode((T0, T1, T2, T3).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { - precondition(self.columns.count >= 5) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { - try self.decode((T0, T1, T2, T3, T4).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { - precondition(self.columns.count >= 6) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { - try self.decode((T0, T1, T2, T3, T4, T5).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { - precondition(self.columns.count >= 7) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { - try self.decode((T0, T1, T2, T3, T4, T5, T6).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { - precondition(self.columns.count >= 8) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { - precondition(self.columns.count >= 9) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { - precondition(self.columns.count >= 10) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { - precondition(self.columns.count >= 11) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { - precondition(self.columns.count >= 12) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 11 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T11.self - let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { - precondition(self.columns.count >= 13) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 11 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T11.self - let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 12 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T12.self - let r12 = try T12._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { - precondition(self.columns.count >= 14) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 11 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T11.self - let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 12 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T12.self - let r12 = try T12._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 13 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T13.self - let r13 = try T13._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { - precondition(self.columns.count >= 15) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 11 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T11.self - let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 12 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T12.self - let r12 = try T12._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 13 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T13.self - let r13 = try T13._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 14 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T14.self - let r14 = try T14._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: .default, file: file, line: line) - } -} -#endif diff --git a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift deleted file mode 100644 index f45357d8..00000000 --- a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift +++ /dev/null @@ -1,215 +0,0 @@ -/// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrowsequence-multi-decode.sh - -#if compiler(<5.9) -extension AsyncSequence where Element == PostgresRow { - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode(T0.self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode(T0.self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: .default, file: file, line: line) - } -} -#endif diff --git a/Sources/PostgresNIO/New/VariadicGenerics.swift b/Sources/PostgresNIO/New/VariadicGenerics.swift index 312d36dc..7931c90c 100644 --- a/Sources/PostgresNIO/New/VariadicGenerics.swift +++ b/Sources/PostgresNIO/New/VariadicGenerics.swift @@ -1,4 +1,4 @@ -#if compiler(>=5.9) + extension PostgresRow { // --- snip TODO: Remove once bug is fixed, that disallows tuples of one @inlinable @@ -170,5 +170,3 @@ enum ComputeParameterPackLength { MemoryLayout<(repeat BoolConverter.Bool)>.size / MemoryLayout.stride } } -#endif // compiler(>=5.9) - From d4c2f38ff5b5bdce6fd952ee75670631c4c8b5a4 Mon Sep 17 00:00:00 2001 From: Robert Cottrell Date: Mon, 21 Oct 2024 08:19:08 +0100 Subject: [PATCH 255/292] Allow bindings with optional values in PostgresBindings (#520) --- Sources/PostgresNIO/New/PostgresQuery.swift | 46 ++++++++++++ Tests/IntegrationTests/AsyncTests.swift | 81 +++++++++++++++++++++ 2 files changed, 127 insertions(+) diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index b695dcfe..6449ab29 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -172,6 +172,16 @@ public struct PostgresBindings: Sendable, Hashable { try self.append(value, context: .default) } + @inlinable + public mutating func append(_ value: Optional) throws { + switch value { + case .none: + self.appendNull() + case let .some(value): + try self.append(value) + } + } + @inlinable public mutating func append( _ value: Value, @@ -181,11 +191,34 @@ public struct PostgresBindings: Sendable, Hashable { self.metadata.append(.init(value: value, protected: true)) } + @inlinable + public mutating func append( + _ value: Optional, + context: PostgresEncodingContext + ) throws { + switch value { + case .none: + self.appendNull() + case let .some(value): + try self.append(value, context: context) + } + } + @inlinable public mutating func append(_ value: Value) { self.append(value, context: .default) } + @inlinable + public mutating func append(_ value: Optional) { + switch value { + case .none: + self.appendNull() + case let .some(value): + self.append(value) + } + } + @inlinable public mutating func append( _ value: Value, @@ -195,6 +228,19 @@ public struct PostgresBindings: Sendable, Hashable { self.metadata.append(.init(value: value, protected: true)) } + @inlinable + public mutating func append( + _ value: Optional, + context: PostgresEncodingContext + ) { + switch value { + case .none: + self.appendNull() + case let .some(value): + self.append(value, context: context) + } + } + @inlinable mutating func appendUnprotected( _ value: Value, diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 513157fd..b4c8e93f 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -476,6 +476,87 @@ final class AsyncPostgresConnectionTests: XCTestCase { XCTFail("Unexpected error: \(String(describing: error))") } } + + static let preparedStatementWithOptionalTestTable = "AsyncTestPreparedStatementWithOptionalTestTable" + func testPreparedStatementWithOptionalBinding() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + struct InsertPreparedStatement: PostgresPreparedStatement { + static let name = "INSERT-AsyncTestPreparedStatementWithOptionalTestTable" + + static let sql = #"INSERT INTO "\#(AsyncPostgresConnectionTests.preparedStatementWithOptionalTestTable)" (uuid) VALUES ($1);"# + typealias Row = () + + var uuid: UUID? + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.uuid) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + () + } + } + + struct SelectPreparedStatement: PostgresPreparedStatement { + static let name = "SELECT-AsyncTestPreparedStatementWithOptionalTestTable" + + static let sql = #"SELECT id, uuid FROM "\#(AsyncPostgresConnectionTests.preparedStatementWithOptionalTestTable)" WHERE id <= $1;"# + typealias Row = (Int, UUID?) + + var id: Int + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.id) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode((Int, UUID?).self) + } + } + + do { + try await withTestConnection(on: eventLoop) { connection in + try await connection.query(""" + CREATE TABLE IF NOT EXISTS "\(unescaped: Self.preparedStatementWithOptionalTestTable)" ( + id SERIAL PRIMARY KEY, + uuid UUID + ) + """, + logger: .psqlTest + ) + + _ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest) + + let rows = try await connection.execute(SelectPreparedStatement(id: 3), logger: .psqlTest) + var counter = 0 + for try await (id, uuid) in rows { + Logger.psqlTest.info("Received row", metadata: [ + "id": "\(id)", "uuid": "\(String(describing: uuid))" + ]) + counter += 1 + } + + try await connection.query(""" + DROP TABLE "\(unescaped: Self.preparedStatementWithOptionalTestTable)"; + """, + logger: .psqlTest + ) + } + } catch { + XCTFail("Unexpected error: \(String(describing: error))") + } + } } extension XCTestCase { From f2a6394a2e7157d547727b975fc0328b92f89fb1 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 21 Oct 2024 10:57:36 +0200 Subject: [PATCH 256/292] Support `additionalStartupParameters` in PostgresClient (#521) --- .../PostgresNIO/Pool/ConnectionFactory.swift | 1 + Sources/PostgresNIO/Pool/PostgresClient.swift | 4 +++ .../PostgresClientTests.swift | 35 +++++++++++++++++++ 3 files changed, 40 insertions(+) diff --git a/Sources/PostgresNIO/Pool/ConnectionFactory.swift b/Sources/PostgresNIO/Pool/ConnectionFactory.swift index 77a0c047..319b86c4 100644 --- a/Sources/PostgresNIO/Pool/ConnectionFactory.swift +++ b/Sources/PostgresNIO/Pool/ConnectionFactory.swift @@ -89,6 +89,7 @@ final class ConnectionFactory: Sendable { connectionConfig.options.connectTimeout = TimeAmount(config.options.connectTimeout) connectionConfig.options.tlsServerName = config.options.tlsServerName connectionConfig.options.requireBackendKeyData = config.options.requireBackendKeyData + connectionConfig.options.additionalStartupParameters = config.options.additionalStartupParameters return connectionConfig } diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 0907f1f8..ad8a4bf1 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -106,6 +106,10 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { /// If you are not using Amazon RDS Proxy, you should leave this set to `true` (the default). public var requireBackendKeyData: Bool = true + /// Additional parameters to send to the server on startup. The name value pairs are added to the initial + /// startup message that the client sends to the server. + public var additionalStartupParameters: [(String, String)] = [] + /// The minimum number of connections that the client shall keep open at any time, even if there is no /// demand. Default to `0`. /// diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index d6d89dc3..579c92cd 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -43,6 +43,41 @@ final class PostgresClientTests: XCTestCase { } } + func testApplicationNameIsForwardedCorrectly() async throws { + var mlogger = Logger(label: "test") + mlogger.logLevel = .debug + let logger = mlogger + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8) + self.addTeardownBlock { + try await eventLoopGroup.shutdownGracefully() + } + + var clientConfig = PostgresClient.Configuration.makeTestConfiguration() + let applicationName = "postgres_nio_test_run" + clientConfig.options.additionalStartupParameters = [("application_name", applicationName)] + let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger) + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() + } + + let rows = try await client.query("select * from pg_stat_activity;"); + var applicationNameFound = 0 + for try await row in rows { + let randomAccessRow = row.makeRandomAccess() + if try randomAccessRow["application_name"].decode(String?.self) == applicationName { + applicationNameFound += 1 + } + } + + XCTAssertGreaterThanOrEqual(applicationNameFound, 1) + + taskGroup.cancelAll() + } + } + + func testQueryDirectly() async throws { var mlogger = Logger(label: "test") mlogger.logLevel = .debug From 96ed89ff0dc457a2533bed80d4cf2a87976bc296 Mon Sep 17 00:00:00 2001 From: Thomas Krajacic Date: Sun, 8 Dec 2024 23:04:18 +0100 Subject: [PATCH 257/292] Correctly place the SSL channel handler in front of the PostgresChannelHandler (#527) --- Sources/PostgresNIO/Connection/PostgresConnection.swift | 6 +++--- Sources/PostgresNIO/New/PostgresChannelHandler.swift | 8 ++++---- .../New/PostgresChannelHandlerTests.swift | 6 +++--- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index eb9dc791..229cd647 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -60,18 +60,18 @@ public final class PostgresConnection: @unchecked Sendable { func start(configuration: InternalConfiguration) -> EventLoopFuture { // 1. configure handlers - let configureSSLCallback: ((Channel) throws -> ())? + let configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> ())? switch configuration.tls.base { case .prefer(let context), .require(let context): - configureSSLCallback = { channel in + configureSSLCallback = { channel, postgresChannelHandler in channel.eventLoop.assertInEventLoop() let sslHandler = try NIOSSLClientHandler( context: context, serverHostname: configuration.serverNameForTLS ) - try channel.pipeline.syncOperations.addHandler(sslHandler, position: .first) + try channel.pipeline.syncOperations.addHandler(sslHandler, position: .before(postgresChannelHandler)) } case .disable: configureSSLCallback = nil diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index ee2af0fe..0a14849a 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -20,7 +20,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private var decoder: NIOSingleStepByteToMessageProcessor private var encoder: PostgresFrontendMessageEncoder! private let configuration: PostgresConnection.InternalConfiguration - private let configureSSLCallback: ((Channel) throws -> Void)? + private let configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)? private var listenState = ListenStateMachine() private var preparedStatementState = PreparedStatementStateMachine() @@ -29,7 +29,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { configuration: PostgresConnection.InternalConfiguration, eventLoop: EventLoop, logger: Logger, - configureSSLCallback: ((Channel) throws -> Void)? + configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)? ) { self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData) self.eventLoop = eventLoop @@ -46,7 +46,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { eventLoop: EventLoop, state: ConnectionStateMachine = .init(.initialized), logger: Logger = .psqlNoOpLogger, - configureSSLCallback: ((Channel) throws -> Void)? + configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)? ) { self.state = state self.eventLoop = eventLoop @@ -439,7 +439,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { // This method must only be called, if we signalized the StateMachine before that we are // able to setup a SSL connection. do { - try self.configureSSLCallback!(context.channel) + try self.configureSSLCallback!(context.channel, self) let action = self.state.sslHandlerAdded() self.run(action, with: context) } catch { diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index dfdcc53e..a2c90969 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -48,7 +48,7 @@ class PostgresChannelHandlerTests: XCTestCase { var config = self.testConnectionConfiguration() XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) var addSSLCallbackIsHit = false - let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel in + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in addSSLCallbackIsHit = true } let embedded = EmbeddedChannel(handlers: [ @@ -84,7 +84,7 @@ class PostgresChannelHandlerTests: XCTestCase { var config = self.testConnectionConfiguration() XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) var addSSLCallbackIsHit = false - let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel in + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in addSSLCallbackIsHit = true } let eventHandler = TestEventHandler() @@ -114,7 +114,7 @@ class PostgresChannelHandlerTests: XCTestCase { func testSSLUnsupportedClosesConnection() throws { let config = self.testConnectionConfiguration(tls: .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) - let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel in + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in XCTFail("This callback should never be exectuded") throw PSQLError.sslUnsupported } From fd0e415a705c490499f983639b04f491a2ed9d99 Mon Sep 17 00:00:00 2001 From: Thomas Krajacic Date: Tue, 10 Dec 2024 10:11:53 +0100 Subject: [PATCH 258/292] Allow TLS enabled connections when providing an established channel (#526) Co-authored-by: Fabian Fett --- .../PostgresConnection+Configuration.swift | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift index dd0f5404..b260723a 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift @@ -192,9 +192,22 @@ extension PostgresConnection { /// - Parameters: /// - channel: The `NIOCore/Channel` to use. The channel must already be active and connected to an /// endpoint (i.e. `NIOCore/Channel/isActive` must be `true`). - /// - tls: The TLS mode to use. Defaults to ``TLS-swift.struct/disable``. + /// - tls: The TLS mode to use. + public init(establishedChannel channel: Channel, tls: PostgresConnection.Configuration.TLS, username: String, password: String?, database: String?) { + self.init(endpointInfo: .configureChannel(channel), tls: tls, username: username, password: password, database: database) + } + + /// Create a configuration for establishing a connection to a Postgres server over a preestablished + /// `NIOCore/Channel`. + /// + /// This is provided for calling code which wants to manage the underlying connection transport on its + /// own, such as when tunneling a connection through SSH. + /// + /// - Parameters: + /// - channel: The `NIOCore/Channel` to use. The channel must already be active and connected to an + /// endpoint (i.e. `NIOCore/Channel/isActive` must be `true`). public init(establishedChannel channel: Channel, username: String, password: String?, database: String?) { - self.init(endpointInfo: .configureChannel(channel), tls: .disable, username: username, password: password, database: database) + self.init(establishedChannel: channel, tls: .disable, username: username, password: password, database: database) } // MARK: - Implementation details From 045cc49fbe224093cc1d77e79065e9e00081d119 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Sat, 21 Dec 2024 04:57:06 -0600 Subject: [PATCH 259/292] Update DocC settings to latest version of Vapor theme (#529) Update DocC settings to latest version of Vapor theme, for compatibility with Swift 6's DocC changes --- Sources/PostgresNIO/Docs.docc/theme-settings.json | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/Sources/PostgresNIO/Docs.docc/theme-settings.json b/Sources/PostgresNIO/Docs.docc/theme-settings.json index dda76197..911cc1bc 100644 --- a/Sources/PostgresNIO/Docs.docc/theme-settings.json +++ b/Sources/PostgresNIO/Docs.docc/theme-settings.json @@ -1,16 +1,19 @@ { "theme": { - "aside": { "border-radius": "6px", "border-style": "double", "border-width": "3px" }, + "aside": { "border-radius": "16px", "border-style": "double", "border-width": "3px" }, "border-radius": "0", "button": { "border-radius": "16px", "border-width": "1px", "border-style": "solid" }, "code": { "border-radius": "16px", "border-width": "1px", "border-style": "solid" }, "color": { + "fill": { "dark": "#000", "light": "#fff" } "psqlnio": "#336791", "documentation-intro-fill": "radial-gradient(circle at top, var(--color-psqlnio) 30%, #000 100%)", "documentation-intro-accent": "var(--color-psqlnio)", + "documentation-intro-eyebrow": "white", + "documentation-intro-figure": "white", + "documentation-intro-title": "white", "logo-base": { "dark": "#fff", "light": "#000" }, "logo-shape": { "dark": "#000", "light": "#fff" }, - "fill": { "dark": "#000", "light": "#fff" } }, "icons": { "technology": "/postgresnio/images/vapor-postgresnio-logo.svg" } }, From 7c29718fe5631462417ed3350ccc1e131678bf13 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Sat, 21 Dec 2024 05:06:54 -0600 Subject: [PATCH 260/292] Fix malformed JSON in theme settings (#530) Fix malformed JSON in theme settings due to comma misplacement --- Sources/PostgresNIO/Docs.docc/theme-settings.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Sources/PostgresNIO/Docs.docc/theme-settings.json b/Sources/PostgresNIO/Docs.docc/theme-settings.json index 911cc1bc..38914a04 100644 --- a/Sources/PostgresNIO/Docs.docc/theme-settings.json +++ b/Sources/PostgresNIO/Docs.docc/theme-settings.json @@ -5,7 +5,7 @@ "button": { "border-radius": "16px", "border-width": "1px", "border-style": "solid" }, "code": { "border-radius": "16px", "border-width": "1px", "border-style": "solid" }, "color": { - "fill": { "dark": "#000", "light": "#fff" } + "fill": { "dark": "#000", "light": "#fff" }, "psqlnio": "#336791", "documentation-intro-fill": "radial-gradient(circle at top, var(--color-psqlnio) 30%, #000 100%)", "documentation-intro-accent": "var(--color-psqlnio)", @@ -13,7 +13,7 @@ "documentation-intro-figure": "white", "documentation-intro-title": "white", "logo-base": { "dark": "#fff", "light": "#000" }, - "logo-shape": { "dark": "#000", "light": "#fff" }, + "logo-shape": { "dark": "#000", "light": "#fff" } }, "icons": { "technology": "/postgresnio/images/vapor-postgresnio-logo.svg" } }, From d6b6487c967a04000db58e622e78cff91fd5bc26 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 27 Jan 2025 18:29:56 +0100 Subject: [PATCH 261/292] Fix sendable warnings (#533) --- Sources/ConnectionPoolModule/NIOLock.swift | 60 +++++++++++-------- .../NIOLockedValueBox.swift | 46 +++++++++++++- Sources/PostgresNIO/New/PSQLTask.swift | 13 ++-- .../PostgresNIO/PostgresDatabase+Query.swift | 2 +- 4 files changed, 86 insertions(+), 35 deletions(-) diff --git a/Sources/ConnectionPoolModule/NIOLock.swift b/Sources/ConnectionPoolModule/NIOLock.swift index 13a9df4a..b6cd7164 100644 --- a/Sources/ConnectionPoolModule/NIOLock.swift +++ b/Sources/ConnectionPoolModule/NIOLock.swift @@ -24,6 +24,13 @@ import WinSDK import Glibc #elseif canImport(Musl) import Musl +#elseif canImport(Bionic) +import Bionic +#elseif canImport(WASILibc) +import WASILibc +#if canImport(wasi_pthread) +import wasi_pthread +#endif #else #error("The concurrency NIOLock module was unable to identify your C library.") #endif @@ -37,16 +44,16 @@ typealias LockPrimitive = pthread_mutex_t #endif @usableFromInline -enum LockOperations { } +enum LockOperations {} extension LockOperations { @inlinable static func create(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) InitializeSRWLock(mutex) -#else + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) var attr = pthread_mutexattr_t() pthread_mutexattr_init(&attr) debugOnly { @@ -55,43 +62,43 @@ extension LockOperations { let err = pthread_mutex_init(mutex, &attr) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } @inlinable static func destroy(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) // SRWLOCK does not need to be free'd -#else + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) let err = pthread_mutex_destroy(mutex) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } @inlinable static func lock(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) AcquireSRWLockExclusive(mutex) -#else + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) let err = pthread_mutex_lock(mutex) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } @inlinable static func unlock(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) ReleaseSRWLockExclusive(mutex) -#else + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) let err = pthread_mutex_unlock(mutex) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } } @@ -129,9 +136,11 @@ final class LockStorage: ManagedBuffer { @inlinable static func create(value: Value) -> Self { let buffer = Self.create(minimumCapacity: 1) { _ in - return value + value } - // Avoid 'unsafeDowncast' as there is a miscompilation on 5.10. + // Intentionally using a force cast here to avoid a miss compiliation in 5.10. + // This is as fast as an unsafeDownCast since ManagedBuffer is inlined and the optimizer + // can eliminate the upcast/downcast pair let storage = buffer as! Self storage.withUnsafeMutablePointers { _, lockPtr in @@ -165,7 +174,7 @@ final class LockStorage: ManagedBuffer { @inlinable func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { try self.withUnsafeMutablePointerToElements { lockPtr in - return try body(lockPtr) + try body(lockPtr) } } @@ -179,17 +188,14 @@ final class LockStorage: ManagedBuffer { } } -extension LockStorage: @unchecked Sendable { } - /// A threading lock based on `libpthread` instead of `libdispatch`. /// -/// - note: ``NIOLock`` has reference semantics. +/// - Note: ``NIOLock`` has reference semantics. /// /// This object provides a lock on top of a single `pthread_mutex_t`. This kind /// of lock is safe to use with `libpthread`-based threading models, such as the /// one used by NIO. On Windows, the lock is based on the substantially similar /// `SRWLOCK` type. -@usableFromInline struct NIOLock { @usableFromInline internal let _storage: LockStorage @@ -220,7 +226,7 @@ struct NIOLock { @inlinable internal func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { - return try self._storage.withLockPrimitive(body) + try self._storage.withLockPrimitive(body) } } @@ -243,12 +249,12 @@ extension NIOLock { } @inlinable - func withLockVoid(_ body: () throws -> Void) rethrows -> Void { + func withLockVoid(_ body: () throws -> Void) rethrows { try self.withLock(body) } } -extension NIOLock: Sendable {} +extension NIOLock: @unchecked Sendable {} extension UnsafeMutablePointer { @inlinable @@ -264,6 +270,10 @@ extension UnsafeMutablePointer { /// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion. @inlinable internal func debugOnly(_ body: () -> Void) { - // FIXME: duplicated with NIO. - assert({ body(); return true }()) + assert( + { + body() + return true + }() + ) } diff --git a/Sources/ConnectionPoolModule/NIOLockedValueBox.swift b/Sources/ConnectionPoolModule/NIOLockedValueBox.swift index e5a3e6a2..c9cd89e0 100644 --- a/Sources/ConnectionPoolModule/NIOLockedValueBox.swift +++ b/Sources/ConnectionPoolModule/NIOLockedValueBox.swift @@ -17,7 +17,7 @@ /// Provides locked access to `Value`. /// -/// - note: ``NIOLockedValueBox`` has reference semantics and holds the `Value` +/// - Note: ``NIOLockedValueBox`` has reference semantics and holds the `Value` /// alongside a lock behind a reference. /// /// This is no different than creating a ``Lock`` and protecting all @@ -39,8 +39,48 @@ struct NIOLockedValueBox { /// Access the `Value`, allowing mutation of it. @inlinable func withLockedValue(_ mutate: (inout Value) throws -> T) rethrows -> T { - return try self._storage.withLockedValue(mutate) + try self._storage.withLockedValue(mutate) + } + + /// Provides an unsafe view over the lock and its value. + /// + /// This can be beneficial when you require fine grained control over the lock in some + /// situations but don't want lose the benefits of ``withLockedValue(_:)`` in others by + /// switching to ``NIOLock``. + var unsafe: Unsafe { + Unsafe(_storage: self._storage) + } + + /// Provides an unsafe view over the lock and its value. + struct Unsafe { + @usableFromInline + let _storage: LockStorage + + /// Manually acquire the lock. + @inlinable + func lock() { + self._storage.lock() + } + + /// Manually release the lock. + @inlinable + func unlock() { + self._storage.unlock() + } + + /// Mutate the value, assuming the lock has been acquired manually. + /// + /// - Parameter mutate: A closure with scoped access to the value. + /// - Returns: The result of the `mutate` closure. + @inlinable + func withValueAssumingLockIsAcquired( + _ mutate: (_ value: inout Value) throws -> Result + ) rethrows -> Result { + try self._storage.withUnsafeMutablePointerToHeader { value in + try mutate(&value.pointee) + } + } } } -extension NIOLockedValueBox: Sendable where Value: Sendable {} +extension NIOLockedValueBox: @unchecked Sendable where Value: Sendable {} diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 363f9394..6106fd21 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -1,7 +1,7 @@ import Logging import NIOCore -enum HandlerTask { +enum HandlerTask: Sendable { case extendedQuery(ExtendedQueryContext) case closeCommand(CloseCommandContext) case startListening(NotificationListener) @@ -31,7 +31,7 @@ enum PSQLTask { } } -final class ExtendedQueryContext { +final class ExtendedQueryContext: Sendable { enum Query { case unnamed(PostgresQuery, EventLoopPromise) case executeStatement(PSQLExecuteStatement, EventLoopPromise) @@ -100,14 +100,15 @@ final class PreparedStatementContext: Sendable { } } -final class CloseCommandContext { +final class CloseCommandContext: Sendable { let target: CloseTarget let logger: Logger let promise: EventLoopPromise - init(target: CloseTarget, - logger: Logger, - promise: EventLoopPromise + init( + target: CloseTarget, + logger: Logger, + promise: EventLoopPromise ) { self.target = target self.logger = logger diff --git a/Sources/PostgresNIO/PostgresDatabase+Query.swift b/Sources/PostgresNIO/PostgresDatabase+Query.swift index 483d5a7b..8de93814 100644 --- a/Sources/PostgresNIO/PostgresDatabase+Query.swift +++ b/Sources/PostgresNIO/PostgresDatabase+Query.swift @@ -40,7 +40,7 @@ extension PostgresDatabase { } } -public struct PostgresQueryResult { +public struct PostgresQueryResult: Sendable { public let metadata: PostgresQueryMetadata public let rows: [PostgresRow] } From 8d07f2049531a60c08b8dda7011a3ad8ac3c989b Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 10 Feb 2025 16:00:18 +0100 Subject: [PATCH 262/292] Fix Sendable warnings (#536) --- Package.swift | 2 +- Tests/IntegrationTests/PostgresNIOTests.swift | 22 -------- .../New/PostgresChannelHandlerTests.swift | 2 +- .../New/PostgresConnectionTests.swift | 16 +++--- .../New/PostgresRowSequenceTests.swift | 51 +++++++++++-------- 5 files changed, 41 insertions(+), 52 deletions(-) diff --git a/Package.swift b/Package.swift index 5f6562f6..3dd21c3c 100644 --- a/Package.swift +++ b/Package.swift @@ -20,7 +20,7 @@ let package = Package( dependencies: [ .package(url: "https://github.com/apple/swift-atomics.git", from: "1.2.0"), .package(url: "https://github.com/apple/swift-collections.git", from: "1.0.4"), - .package(url: "https://github.com/apple/swift-nio.git", from: "2.59.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.81.0"), .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.19.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.25.0"), .package(url: "https://github.com/apple/swift-crypto.git", "2.0.0" ..< "4.0.0"), diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index ff59209b..9a58f050 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -1032,28 +1032,6 @@ final class PostgresNIOTests: XCTestCase { } } - func testRemoteTLSServer() { - // postgres://uymgphwj:7_tHbREdRwkqAdu4KoIS7hQnNxr8J1LA@elmer.db.elephantsql.com:5432/uymgphwj - var conn: PostgresConnection? - let logger = Logger(label: "test") - let sslContext = try! NIOSSLContext(configuration: .makeClientConfiguration()) - let config = PostgresConnection.Configuration( - host: "elmer.db.elephantsql.com", - port: 5432, - username: "uymgphwj", - password: "7_tHbREdRwkqAdu4KoIS7hQnNxr8J1LA", - database: "uymgphwj", - tls: .require(sslContext) - ) - XCTAssertNoThrow(conn = try PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger).wait()) - defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var rows: [PostgresRow]? - XCTAssertNoThrow(rows = try conn?.simpleQuery("SELECT version()").wait()) - XCTAssertEqual(rows?.count, 1) - let row = rows?.first?.makeRandomAccess() - XCTAssertEqual(row?[data: "version"].string?.contains("PostgreSQL"), true) - } - @available(*, deprecated, message: "Test deprecated functionality") func testFailingTLSConnectionClosesConnection() { // There was a bug (https://github.com/vapor/postgres-nio/issues/133) where we would hit diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index a2c90969..206f38a3 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -124,7 +124,7 @@ class PostgresChannelHandlerTests: XCTestCase { handler ], loop: self.eventLoop) let eventHandler = TestEventHandler() - try embedded.pipeline.addHandler(eventHandler, position: .last).wait() + try embedded.pipeline.syncOperations.addHandler(eventHandler, position: .last) embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil) XCTAssertTrue(embedded.isActive) diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 0bc61efd..d0f8e2b0 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -40,10 +40,10 @@ class PostgresConnectionTests: XCTestCase { func testOptionsAreSentOnTheWire() async throws { let eventLoop = NIOAsyncTestingEventLoop() - let channel = await NIOAsyncTestingChannel(handlers: [ - ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), - ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), - ], loop: eventLoop) + let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in + try channel.pipeline.syncOperations.addHandlers(ReverseByteToMessageHandler(PSQLFrontendMessageDecoder())) + try channel.pipeline.syncOperations.addHandlers(ReverseMessageToByteHandler(PSQLBackendMessageEncoder())) + } try await channel.connect(to: .makeAddressResolvingHost("localhost", port: 5432)) let configuration = { @@ -640,10 +640,10 @@ class PostgresConnectionTests: XCTestCase { func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() - let channel = await NIOAsyncTestingChannel(handlers: [ - ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), - ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), - ], loop: eventLoop) + let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in + try channel.pipeline.syncOperations.addHandlers(ReverseByteToMessageHandler(PSQLFrontendMessageDecoder())) + try channel.pipeline.syncOperations.addHandlers(ReverseMessageToByteHandler(PSQLBackendMessageEncoder())) + } try await channel.connect(to: .makeAddressResolvingHost("localhost", port: 5432)) let configuration = PostgresConnection.Configuration( diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index 816daf04..9d662252 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -1,6 +1,6 @@ import Atomics import NIOEmbedded -import Dispatch +import NIOPosix import XCTest @testable import PostgresNIO import NIOCore @@ -8,10 +8,10 @@ import Logging final class PostgresRowSequenceTests: XCTestCase { let logger = Logger(label: "PSQLRowStreamTests") - let eventLoop = EmbeddedEventLoop() func testBackpressureWorks() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -19,7 +19,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -41,6 +41,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testCancellationWorksWhileIterating() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -48,7 +49,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -72,6 +73,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testCancellationWorksBeforeIterating() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -79,7 +81,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -97,6 +99,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testDroppingTheSequenceCancelsTheSource() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -104,7 +107,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -117,6 +120,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testStreamBasedOnCompletedQuery() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -124,7 +128,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -144,6 +148,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testStreamIfInitializedWithAllData() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -151,7 +156,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -172,6 +177,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testStreamIfInitializedWithError() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -179,7 +185,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -200,6 +206,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testSucceedingRowContinuationsWorks() async throws { let dataSource = MockRowDataSource() + let eventLoop = NIOSingletons.posixEventLoopGroup.next() let stream = PSQLRowStream( source: .stream( [ @@ -207,14 +214,14 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: eventLoop, logger: self.logger ) - let rowSequence = stream.asyncSequence() + let rowSequence = try await eventLoop.submit { stream.asyncSequence() }.get() var rowIterator = rowSequence.makeAsyncIterator() - DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + eventLoop.scheduleTask(in: .seconds(1)) { let dataRows: [DataRow] = (0..<1).map { [ByteBuffer(integer: Int64($0))] } stream.receive(dataRows) } @@ -222,7 +229,7 @@ final class PostgresRowSequenceTests: XCTestCase { let row1 = try await rowIterator.next() XCTAssertEqual(try row1?.decode(Int.self), 0) - DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + eventLoop.scheduleTask(in: .seconds(1)) { stream.receive(completion: .success("SELECT 1")) } @@ -232,6 +239,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testFailingRowContinuationsWorks() async throws { let dataSource = MockRowDataSource() + let eventLoop = NIOSingletons.posixEventLoopGroup.next() let stream = PSQLRowStream( source: .stream( [ @@ -239,14 +247,14 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: eventLoop, logger: self.logger ) - let rowSequence = stream.asyncSequence() + let rowSequence = try await eventLoop.submit { stream.asyncSequence() }.get() var rowIterator = rowSequence.makeAsyncIterator() - DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + eventLoop.scheduleTask(in: .seconds(1)) { let dataRows: [DataRow] = (0..<1).map { [ByteBuffer(integer: Int64($0))] } stream.receive(dataRows) } @@ -254,7 +262,7 @@ final class PostgresRowSequenceTests: XCTestCase { let row1 = try await rowIterator.next() XCTAssertEqual(try row1?.decode(Int.self), 0) - DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + eventLoop.scheduleTask(in: .seconds(1)) { stream.receive(completion: .failure(PSQLError.serverClosedConnection(underlying: nil))) } @@ -268,6 +276,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testAdaptiveRowBufferShrinksAndGrows() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -275,7 +284,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -332,6 +341,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testAdaptiveRowShrinksToMin() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -339,7 +349,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -386,6 +396,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testStreamBufferAcceptsNewRowsEventhoughItDidntAskForIt() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -393,7 +404,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) From 712740b1f528210a3ce05618336f5c7dd2470bb9 Mon Sep 17 00:00:00 2001 From: Stevenson Michel <130018170+thoven87@users.noreply.github.com> Date: Tue, 11 Feb 2025 05:29:25 -0500 Subject: [PATCH 263/292] Add `withTransaction` API (#519) Co-authored-by: Fabian Fett --- Sources/PostgresNIO/Pool/PostgresClient.swift | 22 ++++ .../PostgresClientTests.swift | 104 ++++++++++++++++++ 2 files changed, 126 insertions(+) diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index ad8a4bf1..e9e947ef 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -307,6 +307,28 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { return try await closure(connection) } + + /// Lease a connection for the provided `closure`'s lifetime. + /// A transation starts with call to withConnection + /// A transaction should end with a call to COMMIT or ROLLBACK + /// COMMIT is called upon successful completion and ROLLBACK is called should any steps fail + /// + /// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture + /// the provided `PostgresConnection`. + /// - Returns: The closure's return value. + public func withTransaction(_ process: (PostgresConnection) async throws -> Result) async throws -> Result { + try await withConnection { connection in + try await connection.query("BEGIN;", logger: self.backgroundLogger) + do { + let value = try await process(connection) + try await connection.query("COMMIT;", logger: self.backgroundLogger) + return value + } catch { + try await connection.query("ROLLBACK;", logger: self.backgroundLogger) + throw error + } + } + } /// Run a query on the Postgres server the client is connected to. /// diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index 579c92cd..167ba298 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -42,6 +42,110 @@ final class PostgresClientTests: XCTestCase { taskGroup.cancelAll() } } + + func testTransaction() async throws { + var mlogger = Logger(label: "test") + mlogger.logLevel = .debug + let logger = mlogger + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8) + self.addTeardownBlock { + try await eventLoopGroup.shutdownGracefully() + } + + let tableName = "test_client_transactions" + + let clientConfig = PostgresClient.Configuration.makeTestConfiguration() + let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger) + + do { + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() + } + + try await client.query( + """ + CREATE TABLE IF NOT EXISTS "\(unescaped: tableName)" ( + id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + uuid UUID NOT NULL + ); + """, + logger: logger + ) + + let iterations = 1000 + + for _ in 0.. Date: Thu, 13 Feb 2025 12:32:47 +0100 Subject: [PATCH 264/292] Improve transaction handling (#538) --- .../Connection/PostgresConnection.swift | 104 ++++++++++++++++++ .../New/PostgresTransactionError.swift | 21 ++++ Sources/PostgresNIO/Pool/PostgresClient.swift | 84 +++++++++++--- .../PostgresClientTests.swift | 12 +- 4 files changed, 199 insertions(+), 22 deletions(-) create mode 100644 Sources/PostgresNIO/New/PostgresTransactionError.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 229cd647..e267d8f9 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -530,6 +530,110 @@ extension PostgresConnection { throw error // rethrow with more metadata } } + + #if compiler(>=6.0) + /// Puts the connection into an open transaction state, for the provided `closure`'s lifetime. + /// + /// The function starts a transaction by running a `BEGIN` query on the connection against the database. It then + /// lends the connection to the user provided closure. The user can then modify the database as they wish. If the user + /// provided closure returns successfully, the function will attempt to commit the changes by running a `COMMIT` + /// query against the database. If the user provided closure throws an error, the function will attempt to rollback the + /// changes made within the closure. + /// + /// - Parameters: + /// - logger: The `Logger` to log into for the transaction. + /// - file: The file, the transaction was started in. Used for better error reporting. + /// - line: The line, the transaction was started in. Used for better error reporting. + /// - closure: The user provided code to modify the database. Use the provided connection to run queries. + /// The connection must stay in the transaction mode. Otherwise this method will throw! + /// - Returns: The closure's return value. + public func withTransaction( + logger: Logger, + file: String = #file, + line: Int = #line, + isolation: isolated (any Actor)? = #isolation, + // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED + // https://github.com/swiftlang/swift/issues/79285 + _ process: (PostgresConnection) async throws -> sending Result) async throws -> sending Result { + do { + try await self.query("BEGIN;", logger: logger) + } catch { + throw PostgresTransactionError(file: file, line: line, beginError: error) + } + + var closureHasFinished: Bool = false + do { + let value = try await process(self) + closureHasFinished = true + try await self.query("COMMIT;", logger: logger) + return value + } catch { + var transactionError = PostgresTransactionError(file: file, line: line) + if !closureHasFinished { + transactionError.closureError = error + do { + try await self.query("ROLLBACK;", logger: logger) + } catch { + transactionError.rollbackError = error + } + } else { + transactionError.commitError = error + } + + throw transactionError + } + } + #else + /// Puts the connection into an open transaction state, for the provided `closure`'s lifetime. + /// + /// The function starts a transaction by running a `BEGIN` query on the connection against the database. It then + /// lends the connection to the user provided closure. The user can then modify the database as they wish. If the user + /// provided closure returns successfully, the function will attempt to commit the changes by running a `COMMIT` + /// query against the database. If the user provided closure throws an error, the function will attempt to rollback the + /// changes made within the closure. + /// + /// - Parameters: + /// - logger: The `Logger` to log into for the transaction. + /// - file: The file, the transaction was started in. Used for better error reporting. + /// - line: The line, the transaction was started in. Used for better error reporting. + /// - closure: The user provided code to modify the database. Use the provided connection to run queries. + /// The connection must stay in the transaction mode. Otherwise this method will throw! + /// - Returns: The closure's return value. + public func withTransaction( + logger: Logger, + file: String = #file, + line: Int = #line, + _ process: (PostgresConnection) async throws -> Result + ) async throws -> Result { + do { + try await self.query("BEGIN;", logger: logger) + } catch { + throw PostgresTransactionError(file: file, line: line, beginError: error) + } + + var closureHasFinished: Bool = false + do { + let value = try await process(self) + closureHasFinished = true + try await self.query("COMMIT;", logger: logger) + return value + } catch { + var transactionError = PostgresTransactionError(file: file, line: line) + if !closureHasFinished { + transactionError.closureError = error + do { + try await self.query("ROLLBACK;", logger: logger) + } catch { + transactionError.rollbackError = error + } + } else { + transactionError.commitError = error + } + + throw transactionError + } + } + #endif } // MARK: EventLoopFuture interface diff --git a/Sources/PostgresNIO/New/PostgresTransactionError.swift b/Sources/PostgresNIO/New/PostgresTransactionError.swift new file mode 100644 index 00000000..35038446 --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresTransactionError.swift @@ -0,0 +1,21 @@ +/// A wrapper around the errors that can occur during a transaction. +public struct PostgresTransactionError: Error { + + /// The file in which the transaction was started + public var file: String + /// The line in which the transaction was started + public var line: Int + + /// The error thrown when running the `BEGIN` query + public var beginError: Error? + /// The error thrown in the transaction closure + public var closureError: Error? + + /// The error thrown while rolling the transaction back. If the ``closureError`` is set, + /// but the ``rollbackError`` is empty, the rollback was successful. If the ``rollbackError`` + /// is set, the rollback failed. + public var rollbackError: Error? + + /// The error thrown while commiting the transaction. + public var commitError: Error? +} diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index e9e947ef..d54e34eb 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -293,13 +293,13 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { return ConnectionAndMetadata(connection: connection, maximalStreamsOnConnection: 1) } } - /// Lease a connection for the provided `closure`'s lifetime. /// /// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture /// the provided `PostgresConnection`. /// - Returns: The closure's return value. + @_disfavoredOverload public func withConnection(_ closure: (PostgresConnection) async throws -> Result) async throws -> Result { let connection = try await self.leaseConnection() @@ -307,28 +307,80 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { return try await closure(connection) } - + + #if compiler(>=6.0) /// Lease a connection for the provided `closure`'s lifetime. - /// A transation starts with call to withConnection - /// A transaction should end with a call to COMMIT or ROLLBACK - /// COMMIT is called upon successful completion and ROLLBACK is called should any steps fail /// /// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture /// the provided `PostgresConnection`. /// - Returns: The closure's return value. - public func withTransaction(_ process: (PostgresConnection) async throws -> Result) async throws -> Result { - try await withConnection { connection in - try await connection.query("BEGIN;", logger: self.backgroundLogger) - do { - let value = try await process(connection) - try await connection.query("COMMIT;", logger: self.backgroundLogger) - return value - } catch { - try await connection.query("ROLLBACK;", logger: self.backgroundLogger) - throw error - } + public func withConnection( + isolation: isolated (any Actor)? = #isolation, + // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED + // https://github.com/swiftlang/swift/issues/79285 + _ closure: (PostgresConnection) async throws -> sending Result) async throws -> sending Result { + let connection = try await self.leaseConnection() + + defer { self.pool.releaseConnection(connection) } + + return try await closure(connection) + } + + /// Lease a connection, which is in an open transaction state, for the provided `closure`'s lifetime. + /// + /// The function leases a connection from the underlying connection pool and starts a transaction by running a `BEGIN` + /// query on the leased connection against the database. It then lends the connection to the user provided closure. + /// The user can then modify the database as they wish. If the user provided closure returns successfully, the function + /// will attempt to commit the changes by running a `COMMIT` query against the database. If the user provided closure + /// throws an error, the function will attempt to rollback the changes made within the closure. + /// + /// - Parameters: + /// - logger: The `Logger` to log into for the transaction. + /// - file: The file, the transaction was started in. Used for better error reporting. + /// - line: The line, the transaction was started in. Used for better error reporting. + /// - closure: The user provided code to modify the database. Use the provided connection to run queries. + /// The connection must stay in the transaction mode. Otherwise this method will throw! + /// - Returns: The closure's return value. + public func withTransaction( + logger: Logger, + file: String = #file, + line: Int = #line, + isolation: isolated (any Actor)? = #isolation, + // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED + // https://github.com/swiftlang/swift/issues/79285 + _ closure: (PostgresConnection) async throws -> sending Result) async throws -> sending Result { + try await self.withConnection { connection in + try await connection.withTransaction(logger: logger, file: file, line: line, closure) + } + } + #else + + /// Lease a connection, which is in an open transaction state, for the provided `closure`'s lifetime. + /// + /// The function leases a connection from the underlying connection pool and starts a transaction by running a `BEGIN` + /// query on the leased connection against the database. It then lends the connection to the user provided closure. + /// The user can then modify the database as they wish. If the user provided closure returns successfully, the function + /// will attempt to commit the changes by running a `COMMIT` query against the database. If the user provided closure + /// throws an error, the function will attempt to rollback the changes made within the closure. + /// + /// - Parameters: + /// - logger: The `Logger` to log into for the transaction. + /// - file: The file, the transaction was started in. Used for better error reporting. + /// - line: The line, the transaction was started in. Used for better error reporting. + /// - closure: The user provided code to modify the database. Use the provided connection to run queries. + /// The connection must stay in the transaction mode. Otherwise this method will throw! + /// - Returns: The closure's return value. + public func withTransaction( + logger: Logger, + file: String = #file, + line: Int = #line, + _ closure: (PostgresConnection) async throws -> Result + ) async throws -> Result { + try await self.withConnection { connection in + try await connection.withTransaction(logger: logger, file: file, line: line, closure) } } + #endif /// Run a query on the Postgres server the client is connected to. /// diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index 167ba298..34a8ad2a 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -77,7 +77,7 @@ final class PostgresClientTests: XCTestCase { for _ in 0.. Date: Wed, 2 Apr 2025 13:20:07 +0200 Subject: [PATCH 265/292] Move ConnectionPool test-utils into separate target (#544) --- Package.swift | 8 +++++ .../ConnectionPoolTestUtils}/MockClock.swift | 33 ++++++++++--------- .../MockConnection.swift | 22 ++++++------- .../MockConnectionFactory.swift | 27 ++++++++------- .../MockPingPongBehaviour.swift | 10 +++--- .../ConnectionPoolTestUtils/MockRequest.swift | 29 ++++++++++++++++ .../ConnectionPoolTests.swift | 3 +- .../ConnectionRequestTests.swift | 1 + .../Mocks/MockRequest.swift | 28 ---------------- .../NoKeepAliveBehaviorTests.swift | 1 + ...oolStateMachine+ConnectionGroupTests.swift | 3 +- ...oolStateMachine+ConnectionStateTests.swift | 1 + .../PoolStateMachine+RequestQueueTests.swift | 1 + .../PoolStateMachineTests.swift | 1 + 14 files changed, 95 insertions(+), 73 deletions(-) rename {Tests/ConnectionPoolModuleTests/Mocks => Sources/ConnectionPoolTestUtils}/MockClock.swift (84%) rename {Tests/ConnectionPoolModuleTests/Mocks => Sources/ConnectionPoolTestUtils}/MockConnection.swift (86%) rename {Tests/ConnectionPoolModuleTests/Mocks => Sources/ConnectionPoolTestUtils}/MockConnectionFactory.swift (79%) rename {Tests/ConnectionPoolModuleTests/Mocks => Sources/ConnectionPoolTestUtils}/MockPingPongBehaviour.swift (84%) create mode 100644 Sources/ConnectionPoolTestUtils/MockRequest.swift delete mode 100644 Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift diff --git a/Package.swift b/Package.swift index 3dd21c3c..ff071f88 100644 --- a/Package.swift +++ b/Package.swift @@ -57,6 +57,13 @@ let package = Package( path: "Sources/ConnectionPoolModule", swiftSettings: swiftSettings ), + .target( + name: "ConnectionPoolTestUtils", + dependencies: [ + "_ConnectionPoolModule", + .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), + ] + ), .testTarget( name: "PostgresNIOTests", dependencies: [ @@ -70,6 +77,7 @@ let package = Package( name: "ConnectionPoolModuleTests", dependencies: [ .target(name: "_ConnectionPoolModule"), + .target(name: "ConnectionPoolTestUtils"), .product(name: "DequeModule", package: "swift-collections"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift b/Sources/ConnectionPoolTestUtils/MockClock.swift similarity index 84% rename from Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift rename to Sources/ConnectionPoolTestUtils/MockClock.swift index cd08d54e..34bf17e3 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift +++ b/Sources/ConnectionPoolTestUtils/MockClock.swift @@ -1,31 +1,32 @@ -@testable import _ConnectionPoolModule +import _ConnectionPoolModule import Atomics import DequeModule +import NIOConcurrencyHelpers @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class MockClock: Clock { - struct Instant: InstantProtocol, Comparable { - typealias Duration = Swift.Duration +public final class MockClock: Clock { + public struct Instant: InstantProtocol, Comparable { + public typealias Duration = Swift.Duration - func advanced(by duration: Self.Duration) -> Self { + public func advanced(by duration: Self.Duration) -> Self { .init(self.base + duration) } - func duration(to other: Self) -> Self.Duration { + public func duration(to other: Self) -> Self.Duration { self.base - other.base } private var base: Swift.Duration - init(_ base: Duration) { + public init(_ base: Duration) { self.base = base } - static func < (lhs: Self, rhs: Self) -> Bool { + public static func < (lhs: Self, rhs: Self) -> Bool { lhs.base < rhs.base } - static func == (lhs: Self, rhs: Self) -> Bool { + public static func == (lhs: Self, rhs: Self) -> Bool { lhs.base == rhs.base } } @@ -58,16 +59,18 @@ final class MockClock: Clock { var continuation: CheckedContinuation } - typealias Duration = Swift.Duration + public typealias Duration = Swift.Duration - var minimumResolution: Duration { .nanoseconds(1) } + public var minimumResolution: Duration { .nanoseconds(1) } - var now: Instant { self.stateBox.withLockedValue { $0.now } } + public var now: Instant { self.stateBox.withLockedValue { $0.now } } private let stateBox = NIOLockedValueBox(State()) private let waiterIDGenerator = ManagedAtomic(0) - func sleep(until deadline: Instant, tolerance: Duration?) async throws { + public init() {} + + public func sleep(until deadline: Instant, tolerance: Duration?) async throws { let waiterID = self.waiterIDGenerator.loadThenWrappingIncrement(ordering: .relaxed) return try await withTaskCancellationHandler { @@ -131,7 +134,7 @@ final class MockClock: Clock { } @discardableResult - func nextTimerScheduled() async -> Instant { + public func nextTimerScheduled() async -> Instant { await withCheckedContinuation { (continuation: CheckedContinuation) in let instant = self.stateBox.withLockedValue { state -> Instant? in if let scheduled = state.nextDeadlines.popFirst() { @@ -149,7 +152,7 @@ final class MockClock: Clock { } } - func advance(to deadline: Instant) { + public func advance(to deadline: Instant) { let waiters = self.stateBox.withLockedValue { state -> ArraySlice in precondition(deadline > state.now, "Time can only move forward") state.now = deadline diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift b/Sources/ConnectionPoolTestUtils/MockConnection.swift similarity index 86% rename from Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift rename to Sources/ConnectionPoolTestUtils/MockConnection.swift index f826ea04..db5c3ef7 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift +++ b/Sources/ConnectionPoolTestUtils/MockConnection.swift @@ -1,11 +1,11 @@ +import _ConnectionPoolModule import DequeModule -@testable import _ConnectionPoolModule +import NIOConcurrencyHelpers -// Sendability enforced through the lock -final class MockConnection: PooledConnection, Sendable { - typealias ID = Int +public final class MockConnection: PooledConnection, Sendable { + public typealias ID = Int - let id: ID + public let id: ID private enum State { case running([CheckedContinuation], [@Sendable ((any Error)?) -> ()]) @@ -15,11 +15,11 @@ final class MockConnection: PooledConnection, Sendable { private let lock: NIOLockedValueBox = NIOLockedValueBox(.running([], [])) - init(id: Int) { + public init(id: Int) { self.id = id } - var signalToClose: Void { + public var signalToClose: Void { get async throws { try await withCheckedThrowingContinuation { continuation in let runRightAway = self.lock.withLockedValue { state -> Bool in @@ -41,7 +41,7 @@ final class MockConnection: PooledConnection, Sendable { } } - func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) { + public func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) { let enqueued = self.lock.withLockedValue { state -> Bool in switch state { case .closed: @@ -64,7 +64,7 @@ final class MockConnection: PooledConnection, Sendable { } } - func close() { + public func close() { let continuations = self.lock.withLockedValue { state -> [CheckedContinuation] in switch state { case .running(let continuations, let callbacks): @@ -81,7 +81,7 @@ final class MockConnection: PooledConnection, Sendable { } } - func closeIfClosing() { + public func closeIfClosing() { let callbacks = self.lock.withLockedValue { state -> [@Sendable ((any Error)?) -> ()] in switch state { case .running, .closed: @@ -100,7 +100,7 @@ final class MockConnection: PooledConnection, Sendable { } extension MockConnection: CustomStringConvertible { - var description: String { + public var description: String { let state = self.lock.withLockedValue { $0 } return "MockConnection(id: \(self.id), state: \(state))" } diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift b/Sources/ConnectionPoolTestUtils/MockConnectionFactory.swift similarity index 79% rename from Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift rename to Sources/ConnectionPoolTestUtils/MockConnectionFactory.swift index 1c9bfff8..59552d30 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift +++ b/Sources/ConnectionPoolTestUtils/MockConnectionFactory.swift @@ -1,14 +1,15 @@ -@testable import _ConnectionPoolModule +import _ConnectionPoolModule import DequeModule +import NIOConcurrencyHelpers @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class MockConnectionFactory: Sendable where Clock.Duration == Duration { - typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator - typealias Request = ConnectionRequest - typealias KeepAliveBehavior = MockPingPongBehavior - typealias MetricsDelegate = NoOpConnectionPoolMetrics - typealias ConnectionID = Int - typealias Connection = MockConnection +public final class MockConnectionFactory: Sendable where Clock.Duration == Duration { + public typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator + public typealias Request = ConnectionRequest + public typealias KeepAliveBehavior = MockPingPongBehavior + public typealias MetricsDelegate = NoOpConnectionPoolMetrics + public typealias ConnectionID = Int + public typealias Connection = MockConnection let stateBox = NIOLockedValueBox(State()) @@ -20,15 +21,17 @@ final class MockConnectionFactory: Sendable where Clo var runningConnections = [ConnectionID: Connection]() } - var pendingConnectionAttemptsCount: Int { + public init() {} + + public var pendingConnectionAttemptsCount: Int { self.stateBox.withLockedValue { $0.attempts.count } } - var runningConnections: [Connection] { + public var runningConnections: [Connection] { self.stateBox.withLockedValue { Array($0.runningConnections.values) } } - func makeConnection( + public func makeConnection( id: Int, for pool: ConnectionPool, NoOpConnectionPoolMetrics, Clock> ) async throws -> ConnectionAndMetadata { @@ -52,7 +55,7 @@ final class MockConnectionFactory: Sendable where Clo } @discardableResult - func nextConnectAttempt(_ closure: (ConnectionID) async throws -> UInt16) async rethrows -> Connection { + public func nextConnectAttempt(_ closure: (ConnectionID) async throws -> UInt16) async rethrows -> Connection { let (connectionID, continuation) = await withCheckedContinuation { (continuation: CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>) in let attempt = self.stateBox.withLockedValue { state -> (ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)? in if let attempt = state.attempts.popFirst() { diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift b/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift similarity index 84% rename from Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift rename to Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift index 637f096c..5a274079 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift +++ b/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift @@ -2,8 +2,8 @@ import DequeModule @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class MockPingPongBehavior: ConnectionKeepAliveBehavior { - let keepAliveFrequency: Duration? +public final class MockPingPongBehavior: ConnectionKeepAliveBehavior { + public let keepAliveFrequency: Duration? let stateBox = NIOLockedValueBox(State()) @@ -13,11 +13,11 @@ final class MockPingPongBehavior: ConnectionKeepAl var waiter = Deque), Never>>() } - init(keepAliveFrequency: Duration?, connectionType: Connection.Type) { + public init(keepAliveFrequency: Duration?, connectionType: Connection.Type) { self.keepAliveFrequency = keepAliveFrequency } - func runKeepAlive(for connection: Connection) async throws { + public func runKeepAlive(for connection: Connection) async throws { precondition(self.keepAliveFrequency != nil) // we currently don't support cancellation when creating a connection @@ -40,7 +40,7 @@ final class MockPingPongBehavior: ConnectionKeepAl } @discardableResult - func nextKeepAlive(_ closure: (Connection) async throws -> Bool) async rethrows -> Connection { + public func nextKeepAlive(_ closure: (Connection) async throws -> Bool) async rethrows -> Connection { let (connection, continuation) = await withCheckedContinuation { (continuation: CheckedContinuation<(Connection, CheckedContinuation), Never>) in let run = self.stateBox.withLockedValue { state -> (Connection, CheckedContinuation)? in if let run = state.runs.popFirst() { diff --git a/Sources/ConnectionPoolTestUtils/MockRequest.swift b/Sources/ConnectionPoolTestUtils/MockRequest.swift new file mode 100644 index 00000000..06fc49bc --- /dev/null +++ b/Sources/ConnectionPoolTestUtils/MockRequest.swift @@ -0,0 +1,29 @@ +import _ConnectionPoolModule + +public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { + public typealias Connection = MockConnection + + public struct ID: Hashable { + var objectID: ObjectIdentifier + + init(_ request: MockRequest) { + self.objectID = ObjectIdentifier(request) + } + } + + public init() {} + + public var id: ID { ID(self) } + + public static func ==(lhs: MockRequest, rhs: MockRequest) -> Bool { + lhs.id == rhs.id + } + + public func hash(into hasher: inout Hasher) { + hasher.combine(self.id) + } + + public func complete(with: Result) { + + } +} diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index 3c0e7a6b..9b3d5871 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -1,7 +1,8 @@ @testable import _ConnectionPoolModule import Atomics -import XCTest +import ConnectionPoolTestUtils import NIOEmbedded +import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) final class ConnectionPoolTests: XCTestCase { diff --git a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift index 5845267f..cbdc4f65 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift @@ -1,4 +1,5 @@ @testable import _ConnectionPoolModule +import ConnectionPoolTestUtils import XCTest final class ConnectionRequestTests: XCTestCase { diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift deleted file mode 100644 index 6aaa9c91..00000000 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift +++ /dev/null @@ -1,28 +0,0 @@ -import _ConnectionPoolModule - -final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { - typealias Connection = MockConnection - - struct ID: Hashable { - var objectID: ObjectIdentifier - - init(_ request: MockRequest) { - self.objectID = ObjectIdentifier(request) - } - } - - var id: ID { ID(self) } - - - static func ==(lhs: MockRequest, rhs: MockRequest) -> Bool { - lhs.id == rhs.id - } - - func hash(into hasher: inout Hasher) { - hasher.combine(self.id) - } - - func complete(with: Result) { - - } -} diff --git a/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift index b817ce19..4ddad00d 100644 --- a/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift +++ b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift @@ -1,4 +1,5 @@ import _ConnectionPoolModule +import ConnectionPoolTestUtils import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift index 6b8d6c6e..3ec7dc80 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift @@ -1,5 +1,6 @@ -import XCTest @testable import _ConnectionPoolModule +import ConnectionPoolTestUtils +import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) final class PoolStateMachine_ConnectionGroupTests: XCTestCase { diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift index bc4c2c4b..77ad713d 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift @@ -1,4 +1,5 @@ @testable import _ConnectionPoolModule +import ConnectionPoolTestUtils import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift index 0231da51..2ec450a6 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift @@ -1,4 +1,5 @@ @testable import _ConnectionPoolModule +import ConnectionPoolTestUtils import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift index 2f3ae617..ca5cb54d 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -1,3 +1,4 @@ +import ConnectionPoolTestUtils import XCTest @testable import _ConnectionPoolModule From b775835ff0dbef8db8af178fb9eff400bbad1582 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 2 Apr 2025 16:32:34 +0200 Subject: [PATCH 266/292] Add Benchmarks for ConnectionPool (#545) --- Benchmarks/.gitignore | 1 + .../ConnectionPoolBenchmarks.swift | 51 +++++++++++++++++++ Benchmarks/Package.swift | 28 ++++++++++ Package.swift | 9 ++-- .../MockConnectionFactory.swift | 15 +++++- .../MockPingPongBehaviour.swift | 3 +- .../ConnectionPoolTests.swift | 2 +- .../ConnectionRequestTests.swift | 2 +- .../NoKeepAliveBehaviorTests.swift | 2 +- ...oolStateMachine+ConnectionGroupTests.swift | 2 +- ...oolStateMachine+ConnectionStateTests.swift | 2 +- .../PoolStateMachine+RequestQueueTests.swift | 2 +- .../PoolStateMachineTests.swift | 4 +- 13 files changed, 110 insertions(+), 13 deletions(-) create mode 100644 Benchmarks/.gitignore create mode 100644 Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift create mode 100644 Benchmarks/Package.swift diff --git a/Benchmarks/.gitignore b/Benchmarks/.gitignore new file mode 100644 index 00000000..24e5b0a1 --- /dev/null +++ b/Benchmarks/.gitignore @@ -0,0 +1 @@ +.build diff --git a/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift b/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift new file mode 100644 index 00000000..98f21f62 --- /dev/null +++ b/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift @@ -0,0 +1,51 @@ +import _ConnectionPoolModule +import _ConnectionPoolTestUtils +import Benchmark + +let benchmarks: @Sendable () -> Void = { + Benchmark("Minimal benchmark", configuration: .init(scalingFactor: .kilo)) { benchmark in + let clock = MockClock() + let factory = MockConnectionFactory(autoMaxStreams: 1) + var configuration = ConnectionPoolConfiguration() + configuration.maximumConnectionSoftLimit = 50 + configuration.maximumConnectionHardLimit = 50 + + let pool = ConnectionPool( + configuration: configuration, + idGenerator: ConnectionIDGenerator(), + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + await withTaskGroup { taskGroup in + + taskGroup.addTask { + await pool.run() + } + + let sequential = benchmark.scaledIterations.upperBound / configuration.maximumConnectionSoftLimit + + for parallel in 0..: Sendable wh var runningConnections = [ConnectionID: Connection]() } - public init() {} + let autoMaxStreams: UInt16? + + public init(autoMaxStreams: UInt16? = nil) { + self.autoMaxStreams = autoMaxStreams + } public var pendingConnectionAttemptsCount: Int { self.stateBox.withLockedValue { $0.attempts.count } @@ -35,6 +39,15 @@ public final class MockConnectionFactory: Sendable wh id: Int, for pool: ConnectionPool, NoOpConnectionPoolMetrics, Clock> ) async throws -> ConnectionAndMetadata { + if let autoMaxStreams = self.autoMaxStreams { + let connection = MockConnection(id: id) + Task { + try? await connection.signalToClose + connection.closeIfClosing() + } + return .init(connection: connection, maximalStreamsOnConnection: autoMaxStreams) + } + // we currently don't support cancellation when creating a connection let result = try await withCheckedThrowingContinuation { (checkedContinuation: CheckedContinuation<(MockConnection, UInt16), any Error>) in let waiter = self.stateBox.withLockedValue { state -> (CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>)? in diff --git a/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift b/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift index 5a274079..de1a7275 100644 --- a/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift +++ b/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift @@ -1,5 +1,6 @@ -@testable import _ConnectionPoolModule +import _ConnectionPoolModule import DequeModule +import NIOConcurrencyHelpers @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) public final class MockPingPongBehavior: ConnectionKeepAliveBehavior { diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index 9b3d5871..c745b4a0 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -1,6 +1,6 @@ @testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils import Atomics -import ConnectionPoolTestUtils import NIOEmbedded import XCTest diff --git a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift index cbdc4f65..537efbd9 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift @@ -1,5 +1,5 @@ @testable import _ConnectionPoolModule -import ConnectionPoolTestUtils +import _ConnectionPoolTestUtils import XCTest final class ConnectionRequestTests: XCTestCase { diff --git a/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift index 4ddad00d..b1b54023 100644 --- a/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift +++ b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift @@ -1,5 +1,5 @@ import _ConnectionPoolModule -import ConnectionPoolTestUtils +import _ConnectionPoolTestUtils import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift index 3ec7dc80..b09bfcb4 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift @@ -1,5 +1,5 @@ @testable import _ConnectionPoolModule -import ConnectionPoolTestUtils +import _ConnectionPoolTestUtils import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift index 77ad713d..7dd2b726 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift @@ -1,5 +1,5 @@ @testable import _ConnectionPoolModule -import ConnectionPoolTestUtils +import _ConnectionPoolTestUtils import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift index 2ec450a6..b74b86cc 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift @@ -1,5 +1,5 @@ @testable import _ConnectionPoolModule -import ConnectionPoolTestUtils +import _ConnectionPoolTestUtils import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift index ca5cb54d..c0b6ddcd 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -1,6 +1,6 @@ -import ConnectionPoolTestUtils -import XCTest @testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils +import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) typealias TestPoolStateMachine = PoolStateMachine< From ecbc3eb092cb41015c02643ff5258cb94ccbd342 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 2 Apr 2025 17:59:26 +0200 Subject: [PATCH 267/292] Make ConnectionPool faster (#546) --- Sources/ConnectionPoolModule/ConnectionRequest.swift | 5 ++++- .../PoolStateMachine+ConnectionGroup.swift | 10 ++++++++-- .../PoolStateMachine+ConnectionState.swift | 4 ++-- Sources/ConnectionPoolTestUtils/MockRequest.swift | 2 +- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/Sources/ConnectionPoolModule/ConnectionRequest.swift b/Sources/ConnectionPoolModule/ConnectionRequest.swift index 19ed9bd2..1d1c55da 100644 --- a/Sources/ConnectionPoolModule/ConnectionRequest.swift +++ b/Sources/ConnectionPoolModule/ConnectionRequest.swift @@ -21,7 +21,8 @@ public struct ConnectionRequest: ConnectionRequest } } -fileprivate let requestIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator() +@usableFromInline +let requestIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator() @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) extension ConnectionPool where Request == ConnectionRequest { @@ -44,6 +45,7 @@ extension ConnectionPool where Request == ConnectionRequest { ) } + @inlinable public func leaseConnection() async throws -> Connection { let requestID = requestIDGenerator.next() @@ -67,6 +69,7 @@ extension ConnectionPool where Request == ConnectionRequest { return connection } + @inlinable public func withConnection(_ closure: (Connection) async throws -> Result) async throws -> Result { let connection = try await self.leaseConnection() defer { self.releaseConnection(connection) } diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift index f26f244d..a8e97ffd 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -132,6 +132,12 @@ extension PoolStateMachine { @usableFromInline var info: ConnectionAvailableInfo + + @inlinable + init(use: ConnectionUse, info: ConnectionAvailableInfo) { + self.use = use + self.info = info + } } mutating func refillConnections() -> [ConnectionRequest] { @@ -623,7 +629,7 @@ extension PoolStateMachine { // MARK: - Private functions - - @usableFromInline + @inlinable /*private*/ func getConnectionUse(index: Int) -> ConnectionUse { switch index { case 0.. AvailableConnectionContext { precondition(self.connections[index].isAvailable) let use = self.getConnectionUse(index: index) diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift index 2fb68a2d..9912f13a 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift @@ -164,7 +164,7 @@ extension PoolStateMachine { } } - @usableFromInline + @inlinable var isLeased: Bool { switch self.state { case .leased: @@ -174,7 +174,7 @@ extension PoolStateMachine { } } - @usableFromInline + @inlinable var isConnected: Bool { switch self.state { case .idle, .leased: diff --git a/Sources/ConnectionPoolTestUtils/MockRequest.swift b/Sources/ConnectionPoolTestUtils/MockRequest.swift index 06fc49bc..5e4e2fc0 100644 --- a/Sources/ConnectionPoolTestUtils/MockRequest.swift +++ b/Sources/ConnectionPoolTestUtils/MockRequest.swift @@ -3,7 +3,7 @@ import _ConnectionPoolModule public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { public typealias Connection = MockConnection - public struct ID: Hashable { + public struct ID: Hashable, Sendable { var objectID: ObjectIdentifier init(_ request: MockRequest) { From 3cac9571a3467cf1ff431e46e98f44761bfdb600 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 25 Apr 2025 12:28:21 +0200 Subject: [PATCH 268/292] Drop support for Swift 5.9 (#549) --- .github/workflows/test.yml | 6 +++--- Package.swift | 2 +- Sources/ConnectionPoolModule/ConnectionPool.swift | 10 ---------- 3 files changed, 4 insertions(+), 14 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8364e8ae..21271124 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,9 +18,9 @@ jobs: fail-fast: false matrix: swift-image: - - swift:5.9-jammy - - swift:5.10-noble + - swift:5.10-jammy - swift:6.0-noble + - swift:6.1-noble - swiftlang/swift:nightly-main-jammy container: ${{ matrix.swift-image }} runs-on: ubuntu-latest @@ -58,7 +58,7 @@ jobs: - postgres-image: postgres:12 postgres-auth: trust container: - image: swift:5.10-noble + image: swift:6.1-noble volumes: [ 'pgrunshare:/var/run/postgresql' ] runs-on: ubuntu-latest env: diff --git a/Package.swift b/Package.swift index 477a3256..8d150788 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.9 +// swift-tools-version:5.10 import PackageDescription let swiftSettings: [SwiftSetting] = [ diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index 5cdb980d..dc564ffc 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -273,13 +273,11 @@ public final class ConnectionPool< public func run() async { await withTaskCancellationHandler { - #if os(Linux) || compiler(>=5.9) if #available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) { return await withDiscardingTaskGroup() { taskGroup in await self.run(in: &taskGroup) } } - #endif return await withTaskGroup(of: Void.self) { taskGroup in await self.run(in: &taskGroup) } @@ -313,14 +311,12 @@ public final class ConnectionPool< case scheduleTimer(StateMachine.Timer) } - #if os(Linux) || compiler(>=5.9) @available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) private func run(in taskGroup: inout DiscardingTaskGroup) async { for await event in self.eventStream { self.runEvent(event, in: &taskGroup) } } - #endif private func run(in taskGroup: inout TaskGroup) async { var running = 0 @@ -507,11 +503,7 @@ public final class ConnectionPool< await withTaskGroup(of: TimerRunResult.self, returning: Void.self) { taskGroup in taskGroup.addTask { do { - #if os(Linux) || compiler(>=5.9) try await self.clock.sleep(for: timer.duration) - #else - try await self.clock.sleep(until: self.clock.now.advanced(by: timer.duration), tolerance: nil) - #endif return .timerTriggered } catch { return .timerCancelled @@ -579,7 +571,6 @@ protocol TaskGroupProtocol { mutating func addTask_(operation: @escaping @Sendable () async -> Void) } -#if os(Linux) || swift(>=5.9) @available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) extension DiscardingTaskGroup: TaskGroupProtocol { @inlinable @@ -587,7 +578,6 @@ extension DiscardingTaskGroup: TaskGroupProtocol { self.addTask(priority: nil, operation: operation) } } -#endif extension TaskGroup: TaskGroupProtocol { @inlinable From eeb29e0b37c5b8a4ea8061d82fcf5058eccd7577 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 25 Apr 2025 12:46:03 +0200 Subject: [PATCH 269/292] Ensure ltree works (#548) --- .../New/Data/String+PostgresCodable.swift | 9 ++- .../PostgresClientTests.swift | 55 +++++++++++++++++++ .../New/Data/String+PSQLCodableTests.swift | 12 ---- .../New/PostgresRowTests.swift | 8 +-- 4 files changed, 67 insertions(+), 17 deletions(-) diff --git a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift index 41091ab3..6bd09e78 100644 --- a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift @@ -36,13 +36,20 @@ extension String: PostgresDecodable { // we can force unwrap here, since this method only fails if there are not enough // bytes available. self = buffer.readString(length: buffer.readableBytes)! + case (_, .uuid): guard let uuid = try? UUID(from: &buffer, type: .uuid, format: format, context: context) else { throw PostgresDecodingError.Code.failure } self = uuid.uuidString + default: - throw PostgresDecodingError.Code.typeMismatch + // We should eagerly try to convert any datatype into a String. For example the oid + // for ltree isn't static. For this reason we should just try to convert anything. + guard let string = buffer.readString(length: buffer.readableBytes) else { + throw PostgresDecodingError.Code.typeMismatch + } + self = string } } } diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index 34a8ad2a..eaf3663f 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -292,6 +292,61 @@ final class PostgresClientTests: XCTestCase { XCTFail("Unexpected error: \(String(reflecting: error))") } } + + func testLTree() async throws { + let tableName = "test_client_ltree" + + var mlogger = Logger(label: "test") + mlogger.logLevel = .debug + let logger = mlogger + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8) + self.addTeardownBlock { + try await eventLoopGroup.shutdownGracefully() + } + + let clientConfig = PostgresClient.Configuration.makeTestConfiguration() + let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger) + + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() + } + + try await client.query("CREATE EXTENSION IF NOT EXISTS ltree;") + + try await client.query("DROP TABLE IF EXISTS \"\(unescaped: tableName)\";") + + try await client.query( + """ + CREATE TABLE IF NOT EXISTS "\(unescaped: tableName)" ( + id SERIAL PRIMARY KEY, + label ltree NOT NULL + ); + """ + ) + + try await client.query( + """ + INSERT INTO "\(unescaped: tableName)" (label) VALUES ('foo.bar.baz') + """ + ) + + let rows = try await client.query( + """ + SELECT id, label FROM "\(unescaped: tableName)" WHERE label ~ 'foo.*' + """ + ) + + var count = 0 + for try await (id, label) in rows.decode((Int, String).self) { + count += 1 + } + XCTAssertEqual(count, 1) + + taskGroup.cancelAll() + } + } + } @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift index 6ff35130..aadeabff 100644 --- a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift @@ -31,18 +31,6 @@ class String_PSQLCodableTests: XCTestCase { } } - func testDecodeFailureFromInvalidType() { - let buffer = ByteBuffer() - let dataTypes: [PostgresDataType] = [.bool, .float4Array, .float8Array] - - for dataType in dataTypes { - var loopBuffer = buffer - XCTAssertThrowsError(try String(from: &loopBuffer, type: dataType, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .typeMismatch) - } - } - } - func testDecodeFromUUID() { let uuid = UUID() var buffer = ByteBuffer() diff --git a/Tests/PostgresNIOTests/New/PostgresRowTests.swift b/Tests/PostgresNIOTests/New/PostgresRowTests.swift index 7aa4c7e6..4a34e3b0 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowTests.swift @@ -172,7 +172,7 @@ final class PostgresRowTests: XCTestCase { name: "name", tableOID: 1, columnAttributeNumber: 1, - dataType: .int8, + dataType: .text, dataTypeSize: 0, dataTypeModifier: 0, format: .binary @@ -185,7 +185,7 @@ final class PostgresRowTests: XCTestCase { columns: rowDescription ) - XCTAssertThrowsError(try row.decode((UUID?, String).self)) { error in + XCTAssertThrowsError(try row.decode((UUID?, Int).self)) { error in guard let psqlError = error as? PostgresDecodingError else { return XCTFail("Unexpected error type") } XCTAssertEqual(psqlError.columnName, "name") @@ -194,8 +194,8 @@ final class PostgresRowTests: XCTestCase { XCTAssertEqual(psqlError.file, #fileID) XCTAssertEqual(psqlError.postgresData, ByteBuffer(integer: 123)) XCTAssertEqual(psqlError.postgresFormat, .binary) - XCTAssertEqual(psqlError.postgresType, .int8) - XCTAssert(psqlError.targetType == String.self) + XCTAssertEqual(psqlError.postgresType, .text) + XCTAssert(psqlError.targetType == Int.self) } } } From f78b2e3b3e5765785bf69cf58e0cf539a1123872 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 28 Apr 2025 19:03:19 +0200 Subject: [PATCH 270/292] Add serial pool benchmark (#550) --- .../ConnectionPoolBenchmarks.swift | 50 ++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift b/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift index 98f21f62..9cc535d4 100644 --- a/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift +++ b/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift @@ -3,7 +3,7 @@ import _ConnectionPoolTestUtils import Benchmark let benchmarks: @Sendable () -> Void = { - Benchmark("Minimal benchmark", configuration: .init(scalingFactor: .kilo)) { benchmark in + Benchmark("Lease/Release 1k requests: 50 parallel", configuration: .init(scalingFactor: .kilo)) { benchmark in let clock = MockClock() let factory = MockConnectionFactory(autoMaxStreams: 1) var configuration = ConnectionPoolConfiguration() @@ -28,6 +28,8 @@ let benchmarks: @Sendable () -> Void = { let sequential = benchmark.scaledIterations.upperBound / configuration.maximumConnectionSoftLimit + benchmark.startMeasurement() + for parallel in 0.. Void = { for i in 0..(autoMaxStreams: 1) + var configuration = ConnectionPoolConfiguration() + configuration.maximumConnectionSoftLimit = 50 + configuration.maximumConnectionHardLimit = 50 + + let pool = ConnectionPool( + configuration: configuration, + idGenerator: ConnectionIDGenerator(), + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + await withTaskGroup { taskGroup in + + taskGroup.addTask { + await pool.run() + } + + let sequential = benchmark.scaledIterations.upperBound / configuration.maximumConnectionSoftLimit + + benchmark.startMeasurement() + + for _ in benchmark.scaledIterations { + do { + try await pool.withConnection { connection in + blackHole(connection) + } + } catch { + fatalError() + } + } + + benchmark.stopMeasurement() + taskGroup.cancelAll() } } From d2d8a38be26b2a7a6a30673e347f42485c028de6 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Mon, 12 May 2025 17:15:34 +0200 Subject: [PATCH 271/292] Add `@inlinable` to `ConnectionPool.run()` (#552) Previously the `ConnectionPool.run()` method wasn't marked as `@inlinable`, because of this we missed an opportunity to specialize the code that is run as part of the events. --- Sources/ConnectionPoolModule/ConnectionPool.swift | 10 +++++++--- Sources/ConnectionPoolModule/PoolStateMachine.swift | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index dc564ffc..b460b263 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -271,6 +271,7 @@ public final class ConnectionPool< } } + @inlinable public func run() async { await withTaskCancellationHandler { if #available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) { @@ -312,13 +313,15 @@ public final class ConnectionPool< } @available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) - private func run(in taskGroup: inout DiscardingTaskGroup) async { + @inlinable + /* private */ func run(in taskGroup: inout DiscardingTaskGroup) async { for await event in self.eventStream { self.runEvent(event, in: &taskGroup) } } - private func run(in taskGroup: inout TaskGroup) async { + @inlinable + /* private */ func run(in taskGroup: inout TaskGroup) async { var running = 0 for await event in self.eventStream { running += 1 @@ -331,7 +334,8 @@ public final class ConnectionPool< } } - private func runEvent(_ event: NewPoolActions, in taskGroup: inout some TaskGroupProtocol) { + @inlinable + /* private */ func runEvent(_ event: NewPoolActions, in taskGroup: inout some TaskGroupProtocol) { switch event { case .makeConnection(let request): self.makeConnection(for: request, in: &taskGroup) diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index 6e41f730..8d995fa2 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -434,6 +434,7 @@ struct PoolStateMachine< fatalError("Unimplemented") } + @usableFromInline mutating func triggerForceShutdown() -> Action { switch self.poolState { case .running: From 44c7b059ecd159dddce593c24a8dad0ad7c5a321 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 14 May 2025 22:34:29 +0200 Subject: [PATCH 272/292] Make SASL faster (#553) --- .../SASLAuthentication+SCRAM-SHA256.swift | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift index 2a717b6b..6d8f0868 100644 --- a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift +++ b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift @@ -292,7 +292,7 @@ internal struct SHA256: SASLAuthenticationMechanism { /// authenticating user. If the closure throws, authentication /// immediately fails with the thrown error. internal init(username: String, password: @escaping () throws -> String) { - self._impl = .init(username: username, passwordGrabber: { _ in try (Array(password().data(using: .utf8)!), []) }, bindingInfo: .unsupported) + self._impl = .init(username: username, passwordGrabber: { _ in try (Array(password().utf8), []) }, bindingInfo: .unsupported) } /// Set up a server-side `SCRAM-SHA-256` authentication. @@ -338,7 +338,7 @@ internal struct SHA256_PLUS: SASLAuthenticationMechanism { /// - channelBindingData: The appropriate data associated with the RFC5056 /// channel binding specified. internal init(username: String, password: @escaping () throws -> String, channelBindingName: String, channelBindingData: [UInt8]) { - self._impl = .init(username: username, passwordGrabber: { _ in try (Array(password().data(using: .utf8)!), []) }, bindingInfo: .bind(channelBindingName, channelBindingData)) + self._impl = .init(username: username, passwordGrabber: { _ in try (Array(password().utf8), []) }, bindingInfo: .bind(channelBindingName, channelBindingData)) } /// Set up a server-side `SCRAM-SHA-256` authentication. @@ -467,7 +467,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { // Calculate `AuthMessage`, `ClientSignature`, and `ClientProof` let saltedPassword = Hi(string: password, salt: serverSalt, iterations: serverIterations) - let clientKey = HMAC.authenticationCode(for: "Client Key".data(using: .utf8)!, using: .init(data: saltedPassword)) + let clientKey = HMAC.authenticationCode(for: Data("Client Key".utf8), using: .init(data: saltedPassword)) let storedKey = SHA256.hash(data: Data(clientKey)) var authMessage = firstMessageBare; authMessage.append(.comma); authMessage.append(contentsOf: message); authMessage.append(.comma); authMessage.append(contentsOf: clientFinalNoProof) let clientSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: storedKey)) @@ -501,7 +501,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { switch incomingAttributes.first { case .v(let verifier): // Verify server signature - let serverKey = HMAC.authenticationCode(for: "Server Key".data(using: .utf8)!, using: .init(data: saltedPassword)) + let serverKey = HMAC.authenticationCode(for: Data("Server Key".utf8), using: .init(data: saltedPassword)) let serverSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: serverKey)) guard Array(serverSignature) == verifier else { @@ -585,7 +585,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { guard nonce == repeatNonce else { throw SASLAuthenticationError.genericAuthenticationFailure } // Compute client signature - let clientKey = HMAC.authenticationCode(for: "Client Key".data(using: .utf8)!, using: .init(data: saltedPassword)) + let clientKey = HMAC.authenticationCode(for: Data("Client Key".utf8), using: .init(data: saltedPassword)) let storedKey = SHA256.hash(data: Data(clientKey)) var authMessage = clientBareFirstMessage; authMessage.append(.comma); authMessage.append(contentsOf: serverFirstMessage); authMessage.append(.comma); authMessage.append(contentsOf: message.dropLast(proof.count + 3)) let clientSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: storedKey)) @@ -604,7 +604,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { guard storedKey == restoredKey else { throw SCRAMServerError.invalidProof } // Compute server signature - let serverKey = HMAC.authenticationCode(for: "Server Key".data(using: .utf8)!, using: .init(data: saltedPassword)) + let serverKey = HMAC.authenticationCode(for: Data("Server Key".utf8), using: .init(data: saltedPassword)) let serverSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: serverKey)) // Generate a `server-final-message` @@ -644,11 +644,16 @@ private func Hi(string: [UInt8], salt: [UInt8], iterations: UInt32) -> [UInt8] { let key = SymmetricKey(data: string) var Ui = HMAC.authenticationCode(for: salt + [0x00, 0x00, 0x00, 0x01], using: key) // salt + 0x00000001 as big-endian var Hi = Array(Ui) - + var uiData = [UInt8]() + uiData.reserveCapacity(32) + Hi.withUnsafeMutableBytes { Hibuf -> Void in for _ in 2...iterations { - Ui = HMAC.authenticationCode(for: Data(Ui), using: key) - + uiData.removeAll(keepingCapacity: true) + uiData.append(contentsOf: Ui) + + Ui = HMAC.authenticationCode(for: uiData, using: key) + Ui.withUnsafeBytes { Uibuf -> Void in for i in 0.. Date: Thu, 15 May 2025 01:11:26 +0200 Subject: [PATCH 273/292] Make SASL really fast (#554) --- Package.swift | 1 + .../SASLAuthentication+SCRAM-SHA256.swift | 39 +++++++------------ 2 files changed, 16 insertions(+), 24 deletions(-) diff --git a/Package.swift b/Package.swift index 8d150788..b125f7a6 100644 --- a/Package.swift +++ b/Package.swift @@ -36,6 +36,7 @@ let package = Package( .target(name: "_ConnectionPoolModule"), .product(name: "Atomics", package: "swift-atomics"), .product(name: "Crypto", package: "swift-crypto"), + .product(name: "_CryptoExtras", package: "swift-crypto"), .product(name: "Logging", package: "swift-log"), .product(name: "Metrics", package: "swift-metrics"), .product(name: "NIO", package: "swift-nio"), diff --git a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift index 6d8f0868..53e9c3f7 100644 --- a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift +++ b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift @@ -1,4 +1,5 @@ import Crypto +import _CryptoExtras import Foundation extension UInt8 { @@ -466,8 +467,8 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { // TODO: Perform `Normalize(password)`, aka the SASLprep profile (RFC4013) of stringprep (RFC3454) // Calculate `AuthMessage`, `ClientSignature`, and `ClientProof` - let saltedPassword = Hi(string: password, salt: serverSalt, iterations: serverIterations) - let clientKey = HMAC.authenticationCode(for: Data("Client Key".utf8), using: .init(data: saltedPassword)) + let saltedPassword = try Hi(string: password, salt: serverSalt, iterations: serverIterations) + let clientKey = HMAC.authenticationCode(for: Data("Client Key".utf8), using: saltedPassword) let storedKey = SHA256.hash(data: Data(clientKey)) var authMessage = firstMessageBare; authMessage.append(.comma); authMessage.append(contentsOf: message); authMessage.append(.comma); authMessage.append(contentsOf: clientFinalNoProof) let clientSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: storedKey)) @@ -485,9 +486,11 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { var clientFinalMessage = clientFinalNoProof; clientFinalMessage.append(.comma) guard let proofPart = SCRAMMessageParser.serialize([.p(Array(clientProof))]) else { throw SASLAuthenticationError.genericAuthenticationFailure } clientFinalMessage.append(contentsOf: proofPart) - + + let saltedPasswordBytes = saltedPassword.withUnsafeBytes { [UInt8]($0) } + // Save state and send - self.state = .clientSentFinalMessage(saltedPassword: saltedPassword, authMessage: authMessage) + self.state = .clientSentFinalMessage(saltedPassword: saltedPasswordBytes, authMessage: authMessage) return .continue(response: clientFinalMessage) } @@ -640,24 +643,12 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { HMAC() == output length of H(). ```` */ -private func Hi(string: [UInt8], salt: [UInt8], iterations: UInt32) -> [UInt8] { - let key = SymmetricKey(data: string) - var Ui = HMAC.authenticationCode(for: salt + [0x00, 0x00, 0x00, 0x01], using: key) // salt + 0x00000001 as big-endian - var Hi = Array(Ui) - var uiData = [UInt8]() - uiData.reserveCapacity(32) - - Hi.withUnsafeMutableBytes { Hibuf -> Void in - for _ in 2...iterations { - uiData.removeAll(keepingCapacity: true) - uiData.append(contentsOf: Ui) - - Ui = HMAC.authenticationCode(for: uiData, using: key) - - Ui.withUnsafeBytes { Uibuf -> Void in - for i in 0.. SymmetricKey { + try KDF.Insecure.PBKDF2.deriveKey( + from: string, + salt: salt, + using: .sha256, + outputByteCount: 32, + unsafeUncheckedRounds: Int(iterations) + ) } From fc357052754e6d704354a4e60dc34c9401305109 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Thu, 15 May 2025 01:29:48 -0500 Subject: [PATCH 274/292] Require the correct minimum version of swift-crypto (#555) KDF was added to swift-crypto in 3.9.0. We now require that as a minimum via the SCRAM-SHA-256 authentication method. --- Package.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Package.swift b/Package.swift index b125f7a6..af2d07ae 100644 --- a/Package.swift +++ b/Package.swift @@ -24,7 +24,7 @@ let package = Package( .package(url: "https://github.com/apple/swift-nio.git", from: "2.81.0"), .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.19.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.25.0"), - .package(url: "https://github.com/apple/swift-crypto.git", "2.0.0" ..< "4.0.0"), + .package(url: "https://github.com/apple/swift-crypto.git", "3.9.0" ..< "4.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.4.1"), .package(url: "https://github.com/apple/swift-log.git", from: "1.5.3"), .package(url: "https://github.com/swift-server/swift-service-lifecycle.git", from: "2.5.0"), From 9ef06112c41feb7170c8c9116361f170e2a2b1d1 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Thu, 15 May 2025 02:29:10 -0500 Subject: [PATCH 275/292] Update README.md to reflect current minimum Swift version (#556) Update README.md to reflect current Swift minimum --- README.md | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index bc56953b..6d03b8da 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,5 @@

- - - - PostgresNIO - +PostgresNIO

@@ -16,7 +12,7 @@ Continuous Integration - Swift 5.8+ + Swift 5.10+ SSWG Incubation Level: Graduated @@ -167,7 +163,7 @@ Please see [SECURITY.md] for details on the security process. [Team Chat]: https://discord.gg/vapor [MIT License]: LICENSE [Continuous Integration]: https://github.com/vapor/postgres-nio/actions -[Swift 5.8]: https://swift.org +[Swift 5.10]: https://swift.org [Security.md]: https://github.com/vapor/.github/blob/main/SECURITY.md [`PostgresConnection`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresconnection From 02f74caf7318aa0953bc2eb82cacb43887b66a42 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 15 May 2025 12:03:16 +0200 Subject: [PATCH 276/292] Use psql-13 in integration tests (#557) --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 21271124..704508ba 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -49,13 +49,13 @@ jobs: postgres-image: - postgres:17 - postgres:15 - - postgres:12 + - postgres:13 include: - postgres-image: postgres:17 postgres-auth: scram-sha-256 - postgres-image: postgres:15 postgres-auth: md5 - - postgres-image: postgres:12 + - postgres-image: postgres:13 postgres-auth: trust container: image: swift:6.1-noble From ccb25dcc428587224633a79c0ce0430eeac3dc0f Mon Sep 17 00:00:00 2001 From: Andreas Bauer Date: Wed, 4 Jun 2025 17:23:53 +0200 Subject: [PATCH 277/292] Remove check if TLSConfiguration changed when producing SSLContext (#560) --- .../PostgresNIO/Pool/ConnectionFactory.swift | 68 ++++++------------- 1 file changed, 22 insertions(+), 46 deletions(-) diff --git a/Sources/PostgresNIO/Pool/ConnectionFactory.swift b/Sources/PostgresNIO/Pool/ConnectionFactory.swift index 319b86c4..31343826 100644 --- a/Sources/PostgresNIO/Pool/ConnectionFactory.swift +++ b/Sources/PostgresNIO/Pool/ConnectionFactory.swift @@ -15,9 +15,9 @@ final class ConnectionFactory: Sendable { struct SSLContextCache: Sendable { enum State { case none - case producing(TLSConfiguration, [CheckedContinuation]) - case cached(TLSConfiguration, NIOSSLContext) - case failed(TLSConfiguration, any Error) + case producing([CheckedContinuation]) + case cached(NIOSSLContext) + case failed(any Error) } var state: State = .none @@ -106,34 +106,17 @@ final class ConnectionFactory: Sendable { let action = self.sslContextBox.withLockedValue { cache -> Action in switch cache.state { case .none: - cache.state = .producing(tlsConfiguration, [continuation]) + cache.state = .producing([continuation]) return .produce - case .cached(let cachedTLSConfiguration, let context): - if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { - return .succeed(context) - } else { - cache.state = .producing(tlsConfiguration, [continuation]) - return .produce - } - - case .failed(let cachedTLSConfiguration, let error): - if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { - return .fail(error) - } else { - cache.state = .producing(tlsConfiguration, [continuation]) - return .produce - } - - case .producing(let cachedTLSConfiguration, var continuations): + case .cached(let context): + return .succeed(context) + case .failed(let error): + return .fail(error) + case .producing(var continuations): continuations.append(continuation) - if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { - cache.state = .producing(cachedTLSConfiguration, continuations) - return .wait - } else { - cache.state = .producing(tlsConfiguration, continuations) - return .produce - } + cache.state = .producing(continuations) + return .wait } } @@ -143,10 +126,7 @@ final class ConnectionFactory: Sendable { case .produce: // TBD: we might want to consider moving this off the concurrent executor - self.reportProduceSSLContextResult( - Result(catching: {try NIOSSLContext(configuration: tlsConfiguration)}), - for: tlsConfiguration - ) + self.reportProduceSSLContextResult(Result(catching: {try NIOSSLContext(configuration: tlsConfiguration)})) case .succeed(let context): continuation.resume(returning: context) @@ -157,7 +137,7 @@ final class ConnectionFactory: Sendable { } } - private func reportProduceSSLContextResult(_ result: Result, for tlsConfiguration: TLSConfiguration) { + private func reportProduceSSLContextResult(_ result: Result) { enum Action { case fail(any Error, [CheckedContinuation]) case succeed(NIOSSLContext, [CheckedContinuation]) @@ -172,19 +152,15 @@ final class ConnectionFactory: Sendable { case .cached, .failed: return .none - case .producing(let cachedTLSConfiguration, let continuations): - if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { - switch result { - case .success(let context): - cache.state = .cached(cachedTLSConfiguration, context) - return .succeed(context, continuations) - - case .failure(let failure): - cache.state = .failed(cachedTLSConfiguration, failure) - return .fail(failure, continuations) - } - } else { - return .none + case .producing(let continuations): + switch result { + case .success(let context): + cache.state = .cached(context) + return .succeed(context, continuations) + + case .failure(let failure): + cache.state = .failed(failure) + return .fail(failure, continuations) } } } From f275afc0e3e01f1552d1b1e7ed11d5b13c92e357 Mon Sep 17 00:00:00 2001 From: Mahdi Bahrami Date: Fri, 13 Jun 2025 11:04:45 +0330 Subject: [PATCH 278/292] `TinyFastSequence` logic fixes (#563) --- Sources/ConnectionPoolModule/TinyFastSequence.swift | 10 ++++++++-- .../TinyFastSequenceTests.swift | 8 ++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/Sources/ConnectionPoolModule/TinyFastSequence.swift b/Sources/ConnectionPoolModule/TinyFastSequence.swift index dff8a30b..df140c98 100644 --- a/Sources/ConnectionPoolModule/TinyFastSequence.swift +++ b/Sources/ConnectionPoolModule/TinyFastSequence.swift @@ -29,6 +29,12 @@ struct TinyFastSequence: Sequence { self.base = .none(reserveCapacity: 0) case 1: self.base = .one(collection.first!, reserveCapacity: 0) + case 2: + self.base = .two( + collection.first!, + collection[collection.index(after: collection.startIndex)], + reserveCapacity: 0 + ) default: if let collection = collection as? Array { self.base = .n(collection) @@ -46,7 +52,7 @@ struct TinyFastSequence: Sequence { case 1: self.base = .one(max2Sequence.first!, reserveCapacity: 0) case 2: - self.base = .n(Array(max2Sequence)) + self.base = .two(max2Sequence.first!, max2Sequence.second!, reserveCapacity: 0) default: fatalError() } @@ -169,7 +175,7 @@ struct TinyFastSequence: Sequence { case .n(let array): if self.index < array.endIndex { - defer { self.index += 1} + defer { self.index += 1 } return array[self.index] } return nil diff --git a/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift index 1a2836b9..602eb799 100644 --- a/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift +++ b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift @@ -46,11 +46,19 @@ final class TinyFastSequenceTests: XCTestCase { XCTAssertEqual(array.capacity, 8) var twoElemSequence = TinyFastSequence([1, 2]) + twoElemSequence.append(3) twoElemSequence.reserveCapacity(8) guard case .n(let array) = twoElemSequence.base else { return XCTFail("Expected sequence to be backed by an array") } XCTAssertEqual(array.capacity, 8) + + var threeElemSequence = TinyFastSequence([1, 2, 3]) + threeElemSequence.reserveCapacity(8) + guard case .n(let array) = threeElemSequence.base else { + return XCTFail("Expected sequence to be backed by an array") + } + XCTAssertEqual(array.capacity, 8) } func testNewSequenceSlowPath() { From c17db2f1be8ae94f3e8745a0cd227a437898ede1 Mon Sep 17 00:00:00 2001 From: Mahdi Bahrami Date: Fri, 13 Jun 2025 15:41:45 +0330 Subject: [PATCH 279/292] Fix a `TinyFastSequence` test was modified in #563 (#564) --- Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift index 602eb799..b2f04544 100644 --- a/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift +++ b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift @@ -46,8 +46,8 @@ final class TinyFastSequenceTests: XCTestCase { XCTAssertEqual(array.capacity, 8) var twoElemSequence = TinyFastSequence([1, 2]) - twoElemSequence.append(3) twoElemSequence.reserveCapacity(8) + twoElemSequence.append(3) guard case .n(let array) = twoElemSequence.base else { return XCTFail("Expected sequence to be backed by an array") } From d50aadeb18cc96509971315b950053aefcc38cc5 Mon Sep 17 00:00:00 2001 From: Niko Dittmar <77522904+nikodittmar@users.noreply.github.com> Date: Mon, 7 Jul 2025 23:14:18 -0700 Subject: [PATCH 280/292] Fix: Correctly decode jsonb to String by stripping version byte (#568) --- .../New/Data/String+PostgresCodable.swift | 6 ++++++ Tests/IntegrationTests/PostgresNIOTests.swift | 16 ++++++++++++++++ .../New/Data/String+PSQLCodableTests.swift | 11 +++++++++++ 3 files changed, 33 insertions(+) diff --git a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift index 6bd09e78..7e8376a7 100644 --- a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift @@ -29,6 +29,12 @@ extension String: PostgresDecodable { context: PostgresDecodingContext ) throws { switch (format, type) { + case (.binary, .jsonb): + // Discard the version byte + guard let version = buffer.readInteger(as: UInt8.self), version == 1 else { + throw PostgresDecodingError.Code.failure + } + self = buffer.readString(length: buffer.readableBytes)! case (_, .varchar), (_, .bpchar), (_, .text), diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index 9a58f050..5d27e36a 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -930,6 +930,22 @@ final class PostgresNIOTests: XCTestCase { } } + func testJSONBDecodeString() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + do { + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query("select '{\"hello\": \"world\"}'::jsonb as data").wait()) + + var resultString: String? + XCTAssertNoThrow(resultString = try rows?.first?.decode(String.self, context: .default)) + + XCTAssertEqual(resultString, "{\"hello\": \"world\"}") + } + } + func testInt4RangeSerialize() async throws { let conn: PostgresConnection = try await PostgresConnection.test(on: eventLoop).get() self.addTeardownBlock { diff --git a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift index aadeabff..c1843c2a 100644 --- a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift @@ -52,4 +52,15 @@ class String_PSQLCodableTests: XCTestCase { XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } + + func testDecodeFromJSONB() { + let json = #"{"hello": "world"}"# + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(1)) + buffer.writeString(json) + + var decoded: String? + XCTAssertNoThrow(decoded = try String(from: &buffer, type: .jsonb, format: .binary, context: .default)) + XCTAssertEqual(decoded, json) + } } From 20a0f2af079049ff498c24b78eb7485fde6aa82d Mon Sep 17 00:00:00 2001 From: Alex Hoppen Date: Tue, 8 Jul 2025 09:37:57 +0200 Subject: [PATCH 281/292] Add message types to support COPY operations (#569) This adds the infrastrucutre to decode messages needed for COPY operations. It does not implement the handling support itself yet. That will be added in a follow-up PR. --- .../ConnectionStateMachine.swift | 6 + .../ExtendedQueryStateMachine.swift | 8 +- .../New/Messages/CopyInMessage.swift | 44 ++++++ .../New/PostgresBackendMessage.swift | 12 +- .../New/PostgresBackendMessageDecoder.swift | 6 + .../New/PostgresChannelHandler.swift | 2 + .../New/PostgresFrontendMessageEncoder.swift | 25 +++ .../ExtendedQueryStateMachineTests.swift | 2 +- .../PSQLBackendMessageEncoder.swift | 14 ++ .../PSQLFrontendMessageDecoder.swift | 12 ++ .../Extensions/PostgresFrontendMessage.swift | 32 ++++ .../New/Messages/CopyTests.swift | 147 ++++++++++++++++++ 12 files changed, 305 insertions(+), 5 deletions(-) create mode 100644 Sources/PostgresNIO/New/Messages/CopyInMessage.swift create mode 100644 Tests/PostgresNIOTests/New/Messages/CopyTests.swift diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 9d264bcc..8560b948 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -752,6 +752,12 @@ struct ConnectionStateMachine { return self.modify(with: action) } + mutating func copyInResponseReceived( + _ copyInResponse: PostgresBackendMessage.CopyInResponse + ) -> ConnectionAction { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + } + mutating func emptyQueryResponseReceived() -> ConnectionAction { guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.emptyQueryResponse)) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 087a6c24..5708b6b9 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -91,7 +91,7 @@ struct ExtendedQueryStateMachine { mutating func cancel() -> Action { switch self.state { case .initialized: - preconditionFailure("Start must be called immediatly after the query was created") + preconditionFailure("Start must be called immediately after the query was created") case .messagesSent(let queryContext), .parseCompleteReceived(let queryContext), @@ -322,6 +322,12 @@ struct ExtendedQueryStateMachine { } } + mutating func copyInResponseReceived( + _ copyInResponse: PostgresBackendMessage.CopyInResponse + ) -> Action { + return self.setAndFireError(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + } + mutating func emptyQueryResponseReceived() -> Action { guard case .bindCompleteReceived(let queryContext) = self.state else { return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) diff --git a/Sources/PostgresNIO/New/Messages/CopyInMessage.swift b/Sources/PostgresNIO/New/Messages/CopyInMessage.swift new file mode 100644 index 00000000..46dec648 --- /dev/null +++ b/Sources/PostgresNIO/New/Messages/CopyInMessage.swift @@ -0,0 +1,44 @@ +extension PostgresBackendMessage { + struct CopyInResponse: Hashable { + enum Format: Int8 { + case textual = 0 + case binary = 1 + } + + let format: Format + let columnFormats: [Format] + + static func decode(from buffer: inout ByteBuffer) throws -> Self { + guard let rawFormat = buffer.readInteger(endianness: .big, as: Int8.self) else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(1, actual: buffer.readableBytes) + } + guard let format = Format(rawValue: rawFormat) else { + throw PSQLPartialDecodingError.unexpectedValue(value: rawFormat) + } + + guard let numColumns = buffer.readInteger(endianness: .big, as: Int16.self) else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(2, actual: buffer.readableBytes) + } + var columnFormatCodes: [Format] = [] + columnFormatCodes.reserveCapacity(Int(numColumns)) + + for _ in 0.. Self { + return PSQLPartialDecodingError( + description: "Unknown message kind: \(messageID)", + file: file, line: line) + } } extension ByteBuffer { diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 0a14849a..baf801e5 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -136,6 +136,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { action = self.state.closeCompletedReceived() case .commandComplete(let commandTag): action = self.state.commandCompletedReceived(commandTag) + case .copyInResponse(let copyInResponse): + action = self.state.copyInResponseReceived(copyInResponse) case .dataRow(let dataRow): action = self.state.dataRowReceived(dataRow) case .emptyQueryResponse: diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift index 97805418..6ca4cc27 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift @@ -167,6 +167,28 @@ struct PostgresFrontendMessageEncoder { self.buffer.writeMultipleIntegers(UInt32(8), Self.sslRequestCode) } + /// Adds the `CopyData` message ID and `dataLength` to the message buffer but not the actual data. + /// + /// The caller of this function is expected to write the encoder's message buffer to the backend after calling this + /// function, followed by sending the actual data to the backend. + mutating func copyDataHeader(dataLength: UInt32) { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .copyData, length: dataLength) + } + + mutating func copyDone() { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .copyDone, length: 0) + } + + mutating func copyFail(message: String) { + self.clearIfNeeded() + var messageBuffer = ByteBuffer() + messageBuffer.writeNullTerminatedString(message) + self.buffer.psqlWriteMultipleIntegers(id: .copyFail, length: UInt32(messageBuffer.readableBytes)) + self.buffer.writeImmutableBuffer(messageBuffer) + } + mutating func sync() { self.clearIfNeeded() self.buffer.psqlWriteMultipleIntegers(id: .sync, length: 0) @@ -197,6 +219,9 @@ struct PostgresFrontendMessageEncoder { private enum FrontendMessageID: UInt8, Hashable, Sendable { case bind = 66 // B case close = 67 // C + case copyData = 100 // d + case copyDone = 99 // c + case copyFail = 102 // f case describe = 68 // D case execute = 69 // E case flush = 72 // H diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index ae484acc..872664af 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -114,7 +114,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { .failQuery(promise, with: psqlError, cleanupContext: .init(action: .close, tasks: [], error: psqlError, closePromise: nil))) } - func testExtendedQueryIsCancelledImmediatly() { + func testExtendedQueryIsCancelledImmediately() { var state = ConnectionStateMachine.readyForQuery() let logger = Logger.psqlTest diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift index 9614bf1e..0c6b37ef 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift @@ -28,6 +28,8 @@ struct PSQLBackendMessageEncoder: MessageToByteEncoder { case .commandComplete(let string): self.encode(messageID: message.id, payload: StringPayload(string), into: &buffer) + case .copyInResponse(let copyInResponse): + self.encode(messageID: message.id, payload: copyInResponse, into: &buffer) case .dataRow(let row): self.encode(messageID: message.id, payload: row, into: &buffer) @@ -99,6 +101,8 @@ extension PostgresBackendMessage { return .closeComplete case .commandComplete: return .commandComplete + case .copyInResponse: + return .copyInResponse case .dataRow: return .dataRow case .emptyQueryResponse: @@ -184,6 +188,16 @@ extension PostgresBackendMessage.BackendKeyData: PSQLMessagePayloadEncodable { } } +extension PostgresBackendMessage.CopyInResponse: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeInteger(Int8(self.format.rawValue)) + buffer.writeInteger(Int16(self.columnFormats.count)) + for columnFormat in columnFormats { + buffer.writeInteger(Int16(columnFormat.rawValue)) + } + } +} + extension DataRow: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { buffer.writeInteger(self.columnCount, as: Int16.self) diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index 55ccd0a9..d913da22 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -168,6 +168,18 @@ extension PostgresFrontendMessage { ) ) + case .copyData: + return .copyData(CopyData(data: buffer)) + + case .copyDone: + return .copyDone + + case .copyFail: + guard let message = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + return .copyFail(CopyFail(message: message)) + case .close: preconditionFailure("TODO: Unimplemented") diff --git a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift index 2532959a..5fc8144b 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift @@ -36,6 +36,14 @@ enum PostgresFrontendMessage: Equatable { let secretKey: Int32 } + struct CopyData: Hashable { + var data: ByteBuffer + } + + struct CopyFail: Hashable { + var message: String + } + enum Close: Hashable { case preparedStatement(String) case portal(String) @@ -170,6 +178,9 @@ enum PostgresFrontendMessage: Equatable { case bind(Bind) case cancel(Cancel) + case copyData(CopyData) + case copyDone + case copyFail(CopyFail) case close(Close) case describe(Describe) case execute(Execute) @@ -186,6 +197,9 @@ enum PostgresFrontendMessage: Equatable { enum ID: UInt8, Equatable { case bind + case copyData + case copyDone + case copyFail case close case describe case execute @@ -201,12 +215,18 @@ enum PostgresFrontendMessage: Equatable { switch rawValue { case UInt8(ascii: "B"): self = .bind + case UInt8(ascii: "c"): + self = .copyDone case UInt8(ascii: "C"): self = .close + case UInt8(ascii: "d"): + self = .copyData case UInt8(ascii: "D"): self = .describe case UInt8(ascii: "E"): self = .execute + case UInt8(ascii: "f"): + self = .copyFail case UInt8(ascii: "H"): self = .flush case UInt8(ascii: "P"): @@ -230,6 +250,12 @@ enum PostgresFrontendMessage: Equatable { switch self { case .bind: return UInt8(ascii: "B") + case .copyData: + return UInt8(ascii: "d") + case .copyDone: + return UInt8(ascii: "c") + case .copyFail: + return UInt8(ascii: "f") case .close: return UInt8(ascii: "C") case .describe: @@ -263,6 +289,12 @@ extension PostgresFrontendMessage { return .bind case .cancel: preconditionFailure("Cancel messages don't have an identifier") + case .copyData: + return .copyData + case .copyDone: + return .copyDone + case .copyFail: + return .copyFail case .close: return .close case .describe: diff --git a/Tests/PostgresNIOTests/New/Messages/CopyTests.swift b/Tests/PostgresNIOTests/New/Messages/CopyTests.swift new file mode 100644 index 00000000..de686ae5 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/CopyTests.swift @@ -0,0 +1,147 @@ +import XCTest +import NIOCore +import NIOTestUtils +@testable import PostgresNIO + +class CopyTests: XCTestCase { + func testDecodeCopyInResponseMessage() throws { + let expected: [PostgresBackendMessage] = [ + .copyInResponse(.init(format: .textual, columnFormats: [.textual, .textual])), + .copyInResponse(.init(format: .binary, columnFormats: [.binary, .binary])), + .copyInResponse(.init(format: .binary, columnFormats: [.textual, .binary])) + ] + + var buffer = ByteBuffer() + + for message in expected { + guard case .copyInResponse(let message) = message else { + return XCTFail("Expected only to get copyInResponse here!") + } + buffer.writeBackendMessage(id: .copyInResponse ) { buffer in + buffer.writeInteger(Int8(message.format.rawValue)) + buffer.writeInteger(Int16(message.columnFormats.count)) + for columnFormat in message.columnFormats { + buffer.writeInteger(UInt16(columnFormat.rawValue)) + } + } + } + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, expected)], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + } + + func testDecodeFailureBecauseOfEmptyMessage() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { _ in} + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + + func testDecodeFailureBecauseOfInvalidFormat() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(20)) // Only 0 and 1 are valid formats + } + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testDecodeFailureBecauseOfMissingColumnNumber() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(0)) + } + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + + func testDecodeFailureBecauseOfMissingColumns() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(0)) + buffer.writeInteger(Int16(20)) // 20 columns promised, none given + } + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testDecodeFailureBecauseOfInvalidColumnFormat() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(0)) + buffer.writeInteger(Int16(1)) + buffer.writeInteger(Int8(20)) // Only 0 and 1 are valid formats + } + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testEncodeCopyDataHeader() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.copyDataHeader(dataLength: 3) + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 5) + XCTAssertEqual(PostgresFrontendMessage.ID.copyData.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 7) + } + + func testEncodeCopyDone() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.copyDone() + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 5) + XCTAssertEqual(PostgresFrontendMessage.ID.copyDone.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 4) + } + + func testEncodeCopyFail() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.copyFail(message: "Oh, no :(") + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 15) + XCTAssertEqual(PostgresFrontendMessage.ID.copyFail.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 14) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "Oh, no :(") + } +} From c7e2cda52d0bf984d4a254ed55aa60ffc068f60e Mon Sep 17 00:00:00 2001 From: Joseph Heck Date: Tue, 8 Jul 2025 08:35:29 -0700 Subject: [PATCH 282/292] Updating error type to use computed properties (#559) --- .../ConnectionPoolError.swift | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/Sources/ConnectionPoolModule/ConnectionPoolError.swift b/Sources/ConnectionPoolModule/ConnectionPoolError.swift index 1f1e1d2c..3abfe778 100644 --- a/Sources/ConnectionPoolModule/ConnectionPoolError.swift +++ b/Sources/ConnectionPoolModule/ConnectionPoolError.swift @@ -1,16 +1,25 @@ public struct ConnectionPoolError: Error, Hashable { - enum Base: Error, Hashable { + @usableFromInline + enum Base: Error, Hashable, Sendable { case requestCancelled case poolShutdown } - private let base: Base + @usableFromInline + let base: Base + @inlinable init(_ base: Base) { self.base = base } /// The connection requests got cancelled - public static let requestCancelled = ConnectionPoolError(.requestCancelled) + @inlinable + public static var requestCancelled: Self { + ConnectionPoolError(.requestCancelled) + } /// The connection requests can't be fulfilled as the pool has already been shutdown - public static let poolShutdown = ConnectionPoolError(.poolShutdown) + @inlinable + public static var poolShutdown: Self { + ConnectionPoolError(.poolShutdown) + } } From ca70d8cd4d50509afefb8a791144ff126600df3e Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 22 Jul 2025 12:44:45 +0200 Subject: [PATCH 283/292] Use `ConnectionLease` to return connection (#571) --- .../ConnectionLease.swift | 17 ++++ .../ConnectionPoolModule/ConnectionPool.swift | 7 +- .../ConnectionRequest.swift | 16 ++-- .../ConnectionPoolTestUtils/MockRequest.swift | 8 +- Sources/PostgresNIO/Pool/PostgresClient.swift | 28 +++--- .../ConnectionPoolTests.swift | 92 +++++++++---------- .../ConnectionRequestTests.swift | 7 +- .../PoolStateMachine+RequestQueueTests.swift | 26 +++--- .../PoolStateMachineTests.swift | 26 +++--- .../PostgresClientTests.swift | 2 +- 10 files changed, 124 insertions(+), 105 deletions(-) create mode 100644 Sources/ConnectionPoolModule/ConnectionLease.swift diff --git a/Sources/ConnectionPoolModule/ConnectionLease.swift b/Sources/ConnectionPoolModule/ConnectionLease.swift new file mode 100644 index 00000000..77591a58 --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionLease.swift @@ -0,0 +1,17 @@ +public struct ConnectionLease: Sendable { + public var connection: Connection + + @usableFromInline + let _release: @Sendable (Connection) -> () + + @inlinable + public init(connection: Connection, release: @escaping @Sendable (Connection) -> Void) { + self.connection = connection + self._release = release + } + + @inlinable + public func release() { + self._release(self.connection) + } +} diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index b460b263..ee72337d 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -88,7 +88,7 @@ public protocol ConnectionRequestProtocol: Sendable { /// A function that is called with a connection or a /// `PoolError`. - func complete(with: Result) + func complete(with: Result, ConnectionPoolError>) } @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) @@ -402,8 +402,11 @@ public final class ConnectionPool< /*private*/ func runRequestAction(_ action: StateMachine.RequestAction) { switch action { case .leaseConnection(let requests, let connection): + let lease = ConnectionLease(connection: connection) { connection in + self.releaseConnection(connection) + } for request in requests { - request.complete(with: .success(connection)) + request.complete(with: .success(lease)) } case .failRequest(let request, let error): diff --git a/Sources/ConnectionPoolModule/ConnectionRequest.swift b/Sources/ConnectionPoolModule/ConnectionRequest.swift index 1d1c55da..d6654a27 100644 --- a/Sources/ConnectionPoolModule/ConnectionRequest.swift +++ b/Sources/ConnectionPoolModule/ConnectionRequest.swift @@ -5,18 +5,18 @@ public struct ConnectionRequest: ConnectionRequest public var id: ID @usableFromInline - private(set) var continuation: CheckedContinuation + private(set) var continuation: CheckedContinuation, any Error> @inlinable init( id: Int, - continuation: CheckedContinuation + continuation: CheckedContinuation, any Error> ) { self.id = id self.continuation = continuation } - public func complete(with result: Result) { + public func complete(with result: Result, ConnectionPoolError>) { self.continuation.resume(with: result) } } @@ -46,7 +46,7 @@ extension ConnectionPool where Request == ConnectionRequest { } @inlinable - public func leaseConnection() async throws -> Connection { + public func leaseConnection() async throws -> ConnectionLease { let requestID = requestIDGenerator.next() let connection = try await withTaskCancellationHandler { @@ -54,7 +54,7 @@ extension ConnectionPool where Request == ConnectionRequest { throw CancellationError() } - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation, Error>) in let request = Request( id: requestID, continuation: continuation @@ -71,8 +71,8 @@ extension ConnectionPool where Request == ConnectionRequest { @inlinable public func withConnection(_ closure: (Connection) async throws -> Result) async throws -> Result { - let connection = try await self.leaseConnection() - defer { self.releaseConnection(connection) } - return try await closure(connection) + let lease = try await self.leaseConnection() + defer { lease.release() } + return try await closure(lease.connection) } } diff --git a/Sources/ConnectionPoolTestUtils/MockRequest.swift b/Sources/ConnectionPoolTestUtils/MockRequest.swift index 5e4e2fc0..3dd8b0fb 100644 --- a/Sources/ConnectionPoolTestUtils/MockRequest.swift +++ b/Sources/ConnectionPoolTestUtils/MockRequest.swift @@ -1,8 +1,6 @@ import _ConnectionPoolModule -public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { - public typealias Connection = MockConnection - +public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { public struct ID: Hashable, Sendable { var objectID: ObjectIdentifier @@ -11,7 +9,7 @@ public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { } } - public init() {} + public init(connectionType: Connection.Type = Connection.self) {} public var id: ID { ID(self) } @@ -23,7 +21,7 @@ public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { hasher.combine(self.id) } - public func complete(with: Result) { + public func complete(with: Result, ConnectionPoolError>) { } } diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index d54e34eb..0279be07 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -301,11 +301,11 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { /// - Returns: The closure's return value. @_disfavoredOverload public func withConnection(_ closure: (PostgresConnection) async throws -> Result) async throws -> Result { - let connection = try await self.leaseConnection() + let lease = try await self.leaseConnection() - defer { self.pool.releaseConnection(connection) } + defer { lease.release() } - return try await closure(connection) + return try await closure(lease.connection) } #if compiler(>=6.0) @@ -319,11 +319,11 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED // https://github.com/swiftlang/swift/issues/79285 _ closure: (PostgresConnection) async throws -> sending Result) async throws -> sending Result { - let connection = try await self.leaseConnection() + let lease = try await self.leaseConnection() - defer { self.pool.releaseConnection(connection) } + defer { lease.release() } - return try await closure(connection) + return try await closure(lease.connection) } /// Lease a connection, which is in an open transaction state, for the provided `closure`'s lifetime. @@ -404,7 +404,8 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line) } - let connection = try await self.leaseConnection() + let lease = try await self.leaseConnection() + let connection = lease.connection var logger = logger logger[postgresMetadataKey: .connectionID] = "\(connection.id)" @@ -419,12 +420,12 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { connection.channel.write(HandlerTask.extendedQuery(context), promise: nil) promise.futureResult.whenFailure { _ in - self.pool.releaseConnection(connection) + lease.release() } return try await promise.futureResult.map { $0.asyncSequence(onFinish: { - self.pool.releaseConnection(connection) + lease.release() }) }.get() } catch var error as PSQLError { @@ -446,7 +447,8 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { let logger = logger ?? Self.loggingDisabled do { - let connection = try await self.leaseConnection() + let lease = try await self.leaseConnection() + let connection = lease.connection let promise = connection.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( @@ -460,11 +462,11 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { connection.channel.write(task, promise: nil) promise.futureResult.whenFailure { _ in - self.pool.releaseConnection(connection) + lease.release() } return try await promise.futureResult - .map { $0.asyncSequence(onFinish: { self.pool.releaseConnection(connection) }) } + .map { $0.asyncSequence(onFinish: { lease.release() }) } .get() .map { try preparedStatement.decodeRow($0) } } catch var error as PSQLError { @@ -504,7 +506,7 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { // MARK: - Private Methods - - private func leaseConnection() async throws -> PostgresConnection { + private func leaseConnection() async throws -> ConnectionLease { if !self.runningAtomic.load(ordering: .relaxed) { self.backgroundLogger.warning("Trying to lease connection from `PostgresClient`, but `PostgresClient.run()` hasn't been called yet.") } diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index c745b4a0..c1ba89cb 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -39,15 +39,13 @@ final class ConnectionPoolTests: XCTestCase { do { for _ in 0..<1000 { async let connectionFuture = try await pool.leaseConnection() - var leasedConnection: MockConnection? + var connectionLease: ConnectionLease? XCTAssertEqual(factory.pendingConnectionAttemptsCount, 0) - leasedConnection = try await connectionFuture - XCTAssertNotNil(leasedConnection) - XCTAssert(createdConnection === leasedConnection) + connectionLease = try await connectionFuture + XCTAssertNotNil(connectionLease) + XCTAssert(createdConnection === connectionLease?.connection) - if let leasedConnection { - pool.releaseConnection(leasedConnection) - } + connectionLease?.release() } } catch { XCTFail("Unexpected error: \(error)") @@ -195,8 +193,8 @@ final class ConnectionPoolTests: XCTestCase { for _ in 0..]() for request in requests { let connection = try await request.future.success - connections.append(connection) + connectionLeases.append(connection) } // Ensure that we got 4 distinct connections - XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 4) + XCTAssertEqual(Set(connectionLeases.lazy.map(\.connection.id)).count, 4) // release all 4 leased connections - for connection in connections { - pool.releaseConnection(connection) + for lease in connectionLeases { + lease.release() } // shutdown @@ -727,7 +725,7 @@ final class ConnectionPoolTests: XCTestCase { // create 4 connection requests let requests = (0..<10).map { ConnectionFuture(id: $0) } pool.leaseConnections(requests) - var connections = [MockConnection]() + var connectionLeases = [ConnectionLease]() await factory.nextConnectAttempt { connectionID in return 10 @@ -735,15 +733,15 @@ final class ConnectionPoolTests: XCTestCase { for request in requests { let connection = try await request.future.success - connections.append(connection) + connectionLeases.append(connection) } // Ensure that all requests got the same connection - XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 1) + XCTAssertEqual(Set(connectionLeases.lazy.map(\.connection.id)).count, 1) // release all 10 leased streams - for connection in connections { - pool.releaseConnection(connection) + for lease in connectionLeases { + lease.release() } for _ in 0..<9 { @@ -792,41 +790,41 @@ final class ConnectionPoolTests: XCTestCase { // create 4 connection requests var requests = (0..<21).map { ConnectionFuture(id: $0) } pool.leaseConnections(requests) - var connections = [MockConnection]() + var connectionLease = [ConnectionLease]() await factory.nextConnectAttempt { connectionID in return 1 } - let connection = try await requests.first!.future.success - connections.append(connection) + let lease = try await requests.first!.future.success + connectionLease.append(lease) requests.removeFirst() - pool.connectionReceivedNewMaxStreamSetting(connection, newMaxStreamSetting: 21) + pool.connectionReceivedNewMaxStreamSetting(lease.connection, newMaxStreamSetting: 21) for (_, request) in requests.enumerated() { let connection = try await request.future.success - connections.append(connection) + connectionLease.append(connection) } // Ensure that all requests got the same connection - XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 1) + XCTAssertEqual(Set(connectionLease.lazy.map(\.connection.id)).count, 1) requests = (22..<42).map { ConnectionFuture(id: $0) } pool.leaseConnections(requests) // release all 21 leased streams in a single call - pool.releaseConnection(connection, streams: 21) + pool.releaseConnection(lease.connection, streams: 21) // ensure all 20 new requests got fulfilled for request in requests { let connection = try await request.future.success - connections.append(connection) + connectionLease.append(connection) } // release all 20 leased streams one by one for _ in requests { - pool.releaseConnection(connection, streams: 1) + pool.releaseConnection(lease.connection, streams: 1) } // shutdown @@ -840,14 +838,14 @@ final class ConnectionPoolTests: XCTestCase { struct ConnectionFuture: ConnectionRequestProtocol { let id: Int - let future: Future + let future: Future> init(id: Int) { self.id = id - self.future = Future(of: MockConnection.self) + self.future = Future(of: ConnectionLease.self) } - func complete(with result: Result) { + func complete(with result: Result, ConnectionPoolError>) { switch result { case .success(let success): self.future.yield(value: success) diff --git a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift index 537efbd9..2952bf8b 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift @@ -6,13 +6,14 @@ final class ConnectionRequestTests: XCTestCase { func testHappyPath() async throws { let mockConnection = MockConnection(id: 1) - let connection = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let lease = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation, any Error>) in let request = ConnectionRequest(id: 42, continuation: continuation) XCTAssertEqual(request.id, 42) - continuation.resume(with: .success(mockConnection)) + let lease = ConnectionLease(connection: mockConnection) { _ in } + continuation.resume(with: .success(lease)) } - XCTAssert(connection === mockConnection) + XCTAssert(lease.connection === mockConnection) } func testSadPath() async throws { diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift index b74b86cc..ddd6a71e 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift @@ -11,7 +11,7 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) XCTAssertEqual(queue.count, 1) XCTAssertFalse(queue.isEmpty) @@ -25,11 +25,11 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { @@ -49,11 +49,11 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { @@ -76,11 +76,11 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { @@ -113,11 +113,11 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift index c0b6ddcd..08afdf8e 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -7,8 +7,8 @@ typealias TestPoolStateMachine = PoolStateMachine< MockConnection, ConnectionIDGenerator, MockConnection.ID, - MockRequest, - MockRequest.ID, + MockRequest, + MockRequest.ID, MockTimerCancellationToken > @@ -75,7 +75,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(createdAction1.connection, .scheduleTimers([])) // lease connection 1 - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .cancelTimers([])) XCTAssertEqual(leaseRequest1.request, .leaseConnection(.init(element: request1), connection1)) @@ -84,13 +84,13 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(stateMachine.releaseConnection(connection1, streams: 1), .none()) // lease connection 1 - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) XCTAssertEqual(leaseRequest2.connection, .cancelTimers([])) XCTAssertEqual(leaseRequest2.request, .leaseConnection(.init(element: request2), connection1)) // request connection while none is available - let request3 = MockRequest() + let request3 = MockRequest(connectionType: MockConnection.self) let leaseRequest3 = stateMachine.leaseConnection(request3) XCTAssertEqual(leaseRequest3.connection, .makeConnection(.init(connectionID: 1), [])) XCTAssertEqual(leaseRequest3.request, .none) @@ -132,7 +132,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(requests.count, 0) // request connection while none exists - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) XCTAssertEqual(leaseRequest1.request, .none) @@ -144,7 +144,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(createdAction1.connection, .none) // request connection while none is available - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) XCTAssertEqual(leaseRequest2.request, .none) @@ -195,13 +195,13 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(createdAction1.connection, .scheduleTimers([])) // lease connection 1 - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .cancelTimers([])) XCTAssertEqual(leaseRequest1.request, .leaseConnection(.init(element: request1), connection1)) // request connection while none is available - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) XCTAssertEqual(leaseRequest2.request, .none) @@ -245,7 +245,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(requests.count, 0) // request connection while none exists - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) XCTAssertEqual(leaseRequest1.request, .none) @@ -287,7 +287,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(requests.count, 0) // request connection while none exists - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) XCTAssertEqual(leaseRequest1.request, .none) @@ -309,7 +309,7 @@ final class PoolStateMachineTests: XCTestCase { connection1.closeIfClosing() // request connection while none exists anymore - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) XCTAssertEqual(leaseRequest2.request, .none) @@ -354,7 +354,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(requests.count, 1) // one connection should exist - let request = MockRequest() + let request = MockRequest(connectionType: MockConnection.self) let leaseRequest = stateMachine.leaseConnection(request) XCTAssertEqual(leaseRequest.connection, .none) XCTAssertEqual(leaseRequest.request, .none) diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index eaf3663f..9ac92754 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -338,7 +338,7 @@ final class PostgresClientTests: XCTestCase { ) var count = 0 - for try await (id, label) in rows.decode((Int, String).self) { + for try await _ in rows.decode((Int, String).self) { count += 1 } XCTAssertEqual(count, 1) From 8ee6118c03501196be183b0938d2ec4478c18954 Mon Sep 17 00:00:00 2001 From: Andreas Bauer Date: Wed, 13 Aug 2025 17:32:39 +0200 Subject: [PATCH 284/292] PostgresClient: Log connection failed events at info level (#575) --- Sources/PostgresNIO/Pool/PostgresClientMetrics.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift b/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift index aa8215db..62fa326a 100644 --- a/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift +++ b/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift @@ -19,7 +19,7 @@ final class PostgresClientMetrics: ConnectionPoolObservabilityDelegate { /// A connection attempt failed with the given error. After some period of /// time ``startedConnecting(id:)`` may be called again. func connectFailed(id: ConnectionID, error: Error) { - self.logger.debug("Connection creation failed", metadata: [ + self.logger.info("Connection creation failed", metadata: [ .connectionID: "\(id)", .error: "\(String(reflecting: error))" ]) From ce59628fbe432b7a98c315b1199e8ba5a5b4a4ae Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Wed, 27 Aug 2025 10:07:49 -0500 Subject: [PATCH 285/292] Fix several CI issues (#579) * Fix several CI issues - Update the list of versions for Swift images - Update actions/checkout to v5 - Disable TSan when building with Swift 5.10 and 6.0 due to swift-crypto hitting a compiler crasher - Add macOS Sequoia tests - Remove Homebrew bug workaround (no longer needed) * Fix macOS test matrix * Tests can't safely make assumptions about stdlib capacity behavior --- .github/workflows/test.yml | 57 +++++++------------ .../TinyFastSequenceTests.swift | 8 +-- 2 files changed, 25 insertions(+), 40 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 704508ba..1eaf9a87 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,7 +21,8 @@ jobs: - swift:5.10-jammy - swift:6.0-noble - swift:6.1-noble - - swiftlang/swift:nightly-main-jammy + - swiftlang/swift:nightly-6.2-noble + - swiftlang/swift:nightly-main-noble container: ${{ matrix.swift-image }} runs-on: ubuntu-latest steps: @@ -32,11 +33,16 @@ jobs: [[ -z "${SWIFT_VERSION}" ]] && SWIFT_VERSION="$(cat /.swift_tag 2>/dev/null || true)" printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" swift --version + - name: Install curl for Codecov + run: apt-get update -y -q && apt-get install -y curl - name: Check out package - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Run unit tests with Thread Sanitizer + shell: bash run: | - swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' --sanitize=thread --enable-code-coverage + # https://github.com/swiftlang/swift/issues/74042 was never fixed in 5.10 and swift-crypto hits it in 6.0 as well + SANITIZE="$([[ "${SWIFT_VERSION}" =~ ^swift-(5|6\.0) ]] || echo '--sanitize=thread')" + swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' ${SANITIZE} --enable-code-coverage - name: Submit code coverage uses: vapor/swift-codecov-action@v0.3 with: @@ -103,15 +109,15 @@ jobs: [[ -z "${SWIFT_VERSION}" ]] && SWIFT_VERSION="$(cat /.swift_tag 2>/dev/null || true)" printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" && swift --version - name: Check out package - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: { path: 'postgres-nio' } - name: Run integration tests run: swift test --package-path postgres-nio --filter=^IntegrationTests - name: Check out postgres-kit dependent - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: { repository: 'vapor/postgres-kit', path: 'postgres-kit' } - name: Check out fluent-postgres-driver dependent - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: { repository: 'vapor/fluent-postgres-driver', path: 'fluent-postgres-driver' } - name: Use local package in dependents run: | @@ -132,11 +138,14 @@ jobs: postgres-auth: # Only test one auth method on macOS, Linux tests will cover the others - scram-sha-256 - xcode-version: - - '~15' + macos-version: + - 'macos-14' + - 'macos-15' include: - - xcode-version: '~15' - macos-version: 'macos-14' + - macos-version: 'macos-14' + xcode-version: 'latest-stable' + - macos-version: 'macos-15' + xcode-version: 'latest-stable' runs-on: ${{ matrix.macos-version }} env: POSTGRES_HOSTNAME: 127.0.0.1 @@ -154,18 +163,13 @@ jobs: - name: Install Postgres, setup DB and auth, and wait for server start run: | export PATH="$(brew --prefix)/opt/${POSTGRES_FORMULA}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test - # ** BEGIN ** Work around bug in both Homebrew and GHA - (brew upgrade python@3.11 || true) && (brew link --force --overwrite python@3.11 || true) - (brew upgrade python@3.12 || true) && (brew link --force --overwrite python@3.12 || true) - (brew upgrade || true) - # ** END ** Work around bug in both Homebrew and GHA brew install --overwrite "${POSTGRES_FORMULA}" brew link --overwrite --force "${POSTGRES_FORMULA}" initdb --locale=C --auth-host "${POSTGRES_AUTH_METHOD}" -U "${POSTGRES_USER}" --pwfile=<(echo "${POSTGRES_PASSWORD}") pg_ctl start --wait timeout-minutes: 15 - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Run all tests run: swift test @@ -175,7 +179,7 @@ jobs: container: swift:noble steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: fetch-depth: 0 # https://github.com/actions/checkout/issues/766 @@ -183,22 +187,3 @@ jobs: run: | git config --global --add safe.directory "${GITHUB_WORKSPACE}" swift package diagnose-api-breaking-changes origin/main - -# gh-codeql: -# if: ${{ false }} -# runs-on: ubuntu-latest -# container: swift:noble -# permissions: { actions: write, contents: read, security-events: write } -# steps: -# - name: Check out code -# uses: actions/checkout@v4 -# - name: Mark repo safe in non-fake global config -# run: git config --global --add safe.directory "${GITHUB_WORKSPACE}" -# - name: Initialize CodeQL -# uses: github/codeql-action/init@v3 -# with: -# languages: swift -# - name: Perform build -# run: swift build -# - name: Run CodeQL analyze -# uses: github/codeql-action/analyze@v3 diff --git a/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift index b2f04544..6be22005 100644 --- a/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift +++ b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift @@ -34,7 +34,7 @@ final class TinyFastSequenceTests: XCTestCase { guard case .n(let array) = emptySequence.base else { return XCTFail("Expected sequence to be backed by an array") } - XCTAssertEqual(array.capacity, 8) + XCTAssertGreaterThanOrEqual(array.capacity, 8) var oneElemSequence = TinyFastSequence(element: 1) oneElemSequence.reserveCapacity(8) @@ -43,7 +43,7 @@ final class TinyFastSequenceTests: XCTestCase { guard case .n(let array) = oneElemSequence.base else { return XCTFail("Expected sequence to be backed by an array") } - XCTAssertEqual(array.capacity, 8) + XCTAssertGreaterThanOrEqual(array.capacity, 8) var twoElemSequence = TinyFastSequence([1, 2]) twoElemSequence.reserveCapacity(8) @@ -51,14 +51,14 @@ final class TinyFastSequenceTests: XCTestCase { guard case .n(let array) = twoElemSequence.base else { return XCTFail("Expected sequence to be backed by an array") } - XCTAssertEqual(array.capacity, 8) + XCTAssertGreaterThanOrEqual(array.capacity, 8) var threeElemSequence = TinyFastSequence([1, 2, 3]) threeElemSequence.reserveCapacity(8) guard case .n(let array) = threeElemSequence.base else { return XCTFail("Expected sequence to be backed by an array") } - XCTAssertEqual(array.capacity, 8) + XCTAssertGreaterThanOrEqual(array.capacity, 8) } func testNewSequenceSlowPath() { From 7a3d19d2692c2db870fdb9feb17e173991af1609 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Wed, 17 Sep 2025 04:50:32 -0500 Subject: [PATCH 286/292] Don't log misleading errors when the client closes a connection (#585) --- Sources/PostgresNIO/New/PostgresChannelHandler.swift | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index baf801e5..bc256203 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -567,8 +567,13 @@ final class PostgresChannelHandler: ChannelDuplexHandler { _ cleanup: ConnectionStateMachine.ConnectionAction.CleanUpContext, context: ChannelHandlerContext ) { - self.logger.debug("Cleaning up and closing connection.", metadata: [.error: "\(cleanup.error)"]) - + // Don't log a misleading error if the client closed the connection. + if cleanup.error.code == .clientClosedConnection { + self.logger.debug("Cleaning up and closing connection.") + } else { + self.logger.debug("Cleaning up and closing connection.", metadata: [.error: "\(cleanup.error)"]) + } + // 1. fail all tasks cleanup.tasks.forEach { task in task.failWithError(cleanup.error) From 6684f15e382b6dcc84915d42392ef3909d1a39fa Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 17 Sep 2025 15:45:00 +0200 Subject: [PATCH 287/292] Drop Swift 5.10 (#586) --- .github/workflows/test.yml | 7 +-- Package.swift | 10 +++- README.md | 4 +- .../Connection/PostgresConnection.swift | 57 +------------------ .../PostgresNIO/New/VariadicGenerics.swift | 12 ++-- Sources/PostgresNIO/Pool/PostgresClient.swift | 44 +++----------- 6 files changed, 30 insertions(+), 104 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1eaf9a87..926f2fbe 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,10 +18,9 @@ jobs: fail-fast: false matrix: swift-image: - - swift:5.10-jammy - - swift:6.0-noble + - swift:6.0-jammy - swift:6.1-noble - - swiftlang/swift:nightly-6.2-noble + - swift:6.2-noble - swiftlang/swift:nightly-main-noble container: ${{ matrix.swift-image }} runs-on: ubuntu-latest @@ -64,7 +63,7 @@ jobs: - postgres-image: postgres:13 postgres-auth: trust container: - image: swift:6.1-noble + image: swift:6.2-noble volumes: [ 'pgrunshare:/var/run/postgresql' ] runs-on: ubuntu-latest env: diff --git a/Package.swift b/Package.swift index af2d07ae..673a4bb2 100644 --- a/Package.swift +++ b/Package.swift @@ -1,9 +1,15 @@ -// swift-tools-version:5.10 +// swift-tools-version:6.0 import PackageDescription +#if compiler(>=6.1) +let swiftSettings: [SwiftSetting] = [] +#else let swiftSettings: [SwiftSetting] = [ - .enableUpcomingFeature("StrictConcurrency"), + // Sadly the 6.0 compiler concurrency checker finds false positives. + // To be able to compile, lets reduce the language version down to 5 for 6.0 only. + .swiftLanguageMode(.v5) ] +#endif let package = Package( name: "postgres-nio", diff --git a/README.md b/README.md index 6d03b8da..fa4495e2 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ Continuous Integration - Swift 5.10+ + Swift 6.0+ SSWG Incubation Level: Graduated @@ -163,7 +163,7 @@ Please see [SECURITY.md] for details on the security process. [Team Chat]: https://discord.gg/vapor [MIT License]: LICENSE [Continuous Integration]: https://github.com/vapor/postgres-nio/actions -[Swift 5.10]: https://swift.org +[Swift 6.0]: https://swift.org [Security.md]: https://github.com/vapor/.github/blob/main/SECURITY.md [`PostgresConnection`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresconnection diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index e267d8f9..fc48fa31 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -531,7 +531,6 @@ extension PostgresConnection { } } - #if compiler(>=6.0) /// Puts the connection into an open transaction state, for the provided `closure`'s lifetime. /// /// The function starts a transaction by running a `BEGIN` query on the connection against the database. It then @@ -552,9 +551,8 @@ extension PostgresConnection { file: String = #file, line: Int = #line, isolation: isolated (any Actor)? = #isolation, - // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED - // https://github.com/swiftlang/swift/issues/79285 - _ process: (PostgresConnection) async throws -> sending Result) async throws -> sending Result { + _ process: (PostgresConnection) async throws -> sending Result + ) async throws -> sending Result { do { try await self.query("BEGIN;", logger: logger) } catch { @@ -583,57 +581,6 @@ extension PostgresConnection { throw transactionError } } - #else - /// Puts the connection into an open transaction state, for the provided `closure`'s lifetime. - /// - /// The function starts a transaction by running a `BEGIN` query on the connection against the database. It then - /// lends the connection to the user provided closure. The user can then modify the database as they wish. If the user - /// provided closure returns successfully, the function will attempt to commit the changes by running a `COMMIT` - /// query against the database. If the user provided closure throws an error, the function will attempt to rollback the - /// changes made within the closure. - /// - /// - Parameters: - /// - logger: The `Logger` to log into for the transaction. - /// - file: The file, the transaction was started in. Used for better error reporting. - /// - line: The line, the transaction was started in. Used for better error reporting. - /// - closure: The user provided code to modify the database. Use the provided connection to run queries. - /// The connection must stay in the transaction mode. Otherwise this method will throw! - /// - Returns: The closure's return value. - public func withTransaction( - logger: Logger, - file: String = #file, - line: Int = #line, - _ process: (PostgresConnection) async throws -> Result - ) async throws -> Result { - do { - try await self.query("BEGIN;", logger: logger) - } catch { - throw PostgresTransactionError(file: file, line: line, beginError: error) - } - - var closureHasFinished: Bool = false - do { - let value = try await process(self) - closureHasFinished = true - try await self.query("COMMIT;", logger: logger) - return value - } catch { - var transactionError = PostgresTransactionError(file: file, line: line) - if !closureHasFinished { - transactionError.closureError = error - do { - try await self.query("ROLLBACK;", logger: logger) - } catch { - transactionError.rollbackError = error - } - } else { - transactionError.commitError = error - } - - throw transactionError - } - } - #endif } // MARK: EventLoopFuture interface diff --git a/Sources/PostgresNIO/New/VariadicGenerics.swift b/Sources/PostgresNIO/New/VariadicGenerics.swift index 7931c90c..b284c7a2 100644 --- a/Sources/PostgresNIO/New/VariadicGenerics.swift +++ b/Sources/PostgresNIO/New/VariadicGenerics.swift @@ -116,7 +116,8 @@ extension PostgresRow { extension AsyncSequence where Element == PostgresRow { // --- snip TODO: Remove once bug is fixed, that disallows tuples of one @inlinable - public func decode( + @preconcurrency + public func decode( _: Column.Type, context: PostgresDecodingContext, file: String = #fileID, @@ -128,7 +129,8 @@ extension AsyncSequence where Element == PostgresRow { } @inlinable - public func decode( + @preconcurrency + public func decode( _: Column.Type, file: String = #fileID, line: Int = #line @@ -137,7 +139,8 @@ extension AsyncSequence where Element == PostgresRow { } // --- snap TODO: Remove once bug is fixed, that disallows tuples of one - public func decode( + @preconcurrency + public func decode( _ columnType: (repeat each Column).Type, context: PostgresDecodingContext, file: String = #fileID, @@ -148,7 +151,8 @@ extension AsyncSequence where Element == PostgresRow { } } - public func decode( + @preconcurrency + public func decode( _ columnType: (repeat each Column).Type, file: String = #fileID, line: Int = #line diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 0279be07..581b5113 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -308,7 +308,6 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { return try await closure(lease.connection) } - #if compiler(>=6.0) /// Lease a connection for the provided `closure`'s lifetime. /// /// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture @@ -316,9 +315,8 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { /// - Returns: The closure's return value. public func withConnection( isolation: isolated (any Actor)? = #isolation, - // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED - // https://github.com/swiftlang/swift/issues/79285 - _ closure: (PostgresConnection) async throws -> sending Result) async throws -> sending Result { + _ closure: (PostgresConnection) async throws -> sending Result + ) async throws -> sending Result { let lease = try await self.leaseConnection() defer { lease.release() } @@ -346,41 +344,13 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { file: String = #file, line: Int = #line, isolation: isolated (any Actor)? = #isolation, - // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED - // https://github.com/swiftlang/swift/issues/79285 - _ closure: (PostgresConnection) async throws -> sending Result) async throws -> sending Result { - try await self.withConnection { connection in - try await connection.withTransaction(logger: logger, file: file, line: line, closure) + _ closure: (PostgresConnection) async throws -> sending Result + ) async throws -> sending Result { + // for 6.0 to compile we need to explicitly forward the isolation. + try await self.withConnection(isolation: isolation) { connection in + try await connection.withTransaction(logger: logger, file: file, line: line, isolation: isolation, closure) } } - #else - - /// Lease a connection, which is in an open transaction state, for the provided `closure`'s lifetime. - /// - /// The function leases a connection from the underlying connection pool and starts a transaction by running a `BEGIN` - /// query on the leased connection against the database. It then lends the connection to the user provided closure. - /// The user can then modify the database as they wish. If the user provided closure returns successfully, the function - /// will attempt to commit the changes by running a `COMMIT` query against the database. If the user provided closure - /// throws an error, the function will attempt to rollback the changes made within the closure. - /// - /// - Parameters: - /// - logger: The `Logger` to log into for the transaction. - /// - file: The file, the transaction was started in. Used for better error reporting. - /// - line: The line, the transaction was started in. Used for better error reporting. - /// - closure: The user provided code to modify the database. Use the provided connection to run queries. - /// The connection must stay in the transaction mode. Otherwise this method will throw! - /// - Returns: The closure's return value. - public func withTransaction( - logger: Logger, - file: String = #file, - line: Int = #line, - _ closure: (PostgresConnection) async throws -> Result - ) async throws -> Result { - try await self.withConnection { connection in - try await connection.withTransaction(logger: logger, file: file, line: line, closure) - } - } - #endif /// Run a query on the Postgres server the client is connected to. /// From baf4dbf8bfae6d8b73cbf98b4e44594683767d00 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 18 Sep 2025 11:33:51 +0200 Subject: [PATCH 288/292] Adopt Swift Testing in ConnectionPoolTests (#588) --- .../ConnectionIDGeneratorTests.swift | 15 +- .../ConnectionPoolTests.swift | 136 ++++---- .../ConnectionRequestTests.swift | 16 +- .../Max2SequenceTests.swift | 71 ++--- .../NoKeepAliveBehaviorTests.swift | 11 +- ...oolStateMachine+ConnectionGroupTests.swift | 291 +++++++++--------- ...oolStateMachine+ConnectionStateTests.swift | 233 +++++++------- .../PoolStateMachine+RequestQueueTests.swift | 130 ++++---- .../PoolStateMachineTests.swift | 200 ++++++------ .../TinyFastSequenceTests.swift | 76 ++--- 10 files changed, 623 insertions(+), 556 deletions(-) diff --git a/Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift index fb0bfce1..23165746 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionIDGeneratorTests.swift @@ -1,13 +1,14 @@ import _ConnectionPoolModule -import XCTest +import Testing -final class ConnectionIDGeneratorTests: XCTestCase { - func testGenerateConnectionIDs() async { +@Suite struct ConnectionIDGeneratorTests { + + @Test func testGenerateConnectionIDs() async { let idGenerator = ConnectionIDGenerator() - XCTAssertEqual(idGenerator.next(), 0) - XCTAssertEqual(idGenerator.next(), 1) - XCTAssertEqual(idGenerator.next(), 2) + #expect(idGenerator.next() == 0) + #expect(idGenerator.next() == 1) + #expect(idGenerator.next() == 2) await withTaskGroup(of: Void.self) { taskGroup in for _ in 0..<1000 { @@ -17,6 +18,6 @@ final class ConnectionIDGeneratorTests: XCTestCase { } } - XCTAssertEqual(idGenerator.next(), 1003) + #expect(idGenerator.next() == 1003) } } diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index c1ba89cb..f3664242 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -2,12 +2,13 @@ import _ConnectionPoolTestUtils import Atomics import NIOEmbedded -import XCTest +import Testing -@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class ConnectionPoolTests: XCTestCase { - func test1000ConsecutiveRequestsOnSingleConnection() async { +@Suite struct ConnectionPoolTests { + + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func test1000ConsecutiveRequestsOnSingleConnection() async { let factory = MockConnectionFactory() var config = ConnectionPoolConfiguration() @@ -34,35 +35,35 @@ final class ConnectionPoolTests: XCTestCase { let createdConnection = await factory.nextConnectAttempt { _ in return 1 } - XCTAssertNotNil(createdConnection) do { for _ in 0..<1000 { - async let connectionFuture = try await pool.leaseConnection() + async let connectionFuture = pool.leaseConnection() var connectionLease: ConnectionLease? - XCTAssertEqual(factory.pendingConnectionAttemptsCount, 0) + #expect(factory.pendingConnectionAttemptsCount == 0) connectionLease = try await connectionFuture - XCTAssertNotNil(connectionLease) - XCTAssert(createdConnection === connectionLease?.connection) + #expect(connectionLease != nil) + #expect(createdConnection === connectionLease?.connection) connectionLease?.release() } } catch { - XCTFail("Unexpected error: \(error)") + Issue.record("Unexpected error: \(error)") } taskGroup.cancelAll() - XCTAssertEqual(factory.pendingConnectionAttemptsCount, 0) + #expect(factory.pendingConnectionAttemptsCount == 0) for connection in factory.runningConnections { connection.closeIfClosing() } } - XCTAssertEqual(factory.runningConnections.count, 0) + #expect(factory.runningConnections.count == 0) } - func testShutdownPoolWhileConnectionIsBeingCreated() async { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testShutdownPoolWhileConnectionIsBeingCreated() async { let clock = MockClock() let factory = MockConnectionFactory() @@ -107,7 +108,8 @@ final class ConnectionPoolTests: XCTestCase { struct ConnectionCreationError: Error {} } - func testShutdownPoolWhileConnectionIsBackingOff() async { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testShutdownPoolWhileConnectionIsBackingOff() async { let clock = MockClock() let factory = MockConnectionFactory() @@ -142,7 +144,8 @@ final class ConnectionPoolTests: XCTestCase { struct ConnectionCreationError: Error {} } - func testConnectionHardLimitIsRespected() async { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testConnectionHardLimitIsRespected() async { let factory = MockConnectionFactory() var mutableConfig = ConnectionPoolConfiguration() @@ -171,21 +174,21 @@ final class ConnectionPoolTests: XCTestCase { await withTaskGroup(of: Void.self) { taskGroup in taskGroup.addTask_ { await pool.run() - XCTAssertFalse(hasFinished.compareExchange(expected: false, desired: true, ordering: .relaxed).original) + #expect(hasFinished.compareExchange(expected: false, desired: true, ordering: .relaxed).original == false) } taskGroup.addTask_ { var usedConnectionIDs = Set() for _ in 0..() let keepAliveDuration = Duration.seconds(30) @@ -255,7 +259,7 @@ final class ConnectionPoolTests: XCTestCase { } let connectionLease = try await connectionLeaseFuture - XCTAssert(connection === connectionLease.connection) + #expect(connection === connectionLease.connection) connectionLease.release() @@ -265,11 +269,11 @@ final class ConnectionPoolTests: XCTestCase { var expectedInstants: Set = [.init(keepAliveDuration), .init(config.idleTimeout)] let deadline1 = await clock.nextTimerScheduled() print(deadline1) - XCTAssertNotNil(expectedInstants.remove(deadline1)) + #expect(expectedInstants.remove(deadline1) != nil) let deadline2 = await clock.nextTimerScheduled() print(deadline2) - XCTAssertNotNil(expectedInstants.remove(deadline2)) - XCTAssert(expectedInstants.isEmpty) + #expect(expectedInstants.remove(deadline2) != nil) + #expect(expectedInstants.isEmpty == true) // move clock forward to keep alive let newTime = clock.now.advanced(by: keepAliveDuration) @@ -278,14 +282,14 @@ final class ConnectionPoolTests: XCTestCase { await keepAlive.nextKeepAlive { keepAliveConnection in defer { print("keep alive 1 has run") } - XCTAssertTrue(keepAliveConnection === connectionLease.connection) + #expect(keepAliveConnection === connectionLease.connection) return true } // keep alive 2 let deadline3 = await clock.nextTimerScheduled() - XCTAssertEqual(deadline3, clock.now.advanced(by: keepAliveDuration)) + #expect(deadline3 == clock.now.advanced(by: keepAliveDuration)) print(deadline3) // race keep alive vs timeout @@ -299,7 +303,8 @@ final class ConnectionPoolTests: XCTestCase { } } - func testKeepAliveOnClose() async throws { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testKeepAliveOnClose() async throws { let clock = MockClock() let factory = MockConnectionFactory() let keepAliveDuration = Duration.seconds(20) @@ -334,7 +339,7 @@ final class ConnectionPoolTests: XCTestCase { } let connectionLease = try await connectionLeaseFuture - XCTAssert(connection === connectionLease.connection) + #expect(connection === connectionLease.connection) connectionLease.release() @@ -344,38 +349,38 @@ final class ConnectionPoolTests: XCTestCase { var expectedInstants: Set = [.init(keepAliveDuration), .init(config.idleTimeout)] let deadline1 = await clock.nextTimerScheduled() print(deadline1) - XCTAssertNotNil(expectedInstants.remove(deadline1)) + #expect(expectedInstants.remove(deadline1) != nil) let deadline2 = await clock.nextTimerScheduled() print(deadline2) - XCTAssertNotNil(expectedInstants.remove(deadline2)) - XCTAssert(expectedInstants.isEmpty) + #expect(expectedInstants.remove(deadline2) != nil) + #expect(expectedInstants.isEmpty) // move clock forward to keep alive let newTime = clock.now.advanced(by: keepAliveDuration) clock.advance(to: newTime) await keepAlive.nextKeepAlive { keepAliveConnection in - XCTAssertTrue(keepAliveConnection === connectionLease.connection) + #expect(keepAliveConnection === connectionLease.connection) return true } // keep alive 2 let deadline3 = await clock.nextTimerScheduled() - XCTAssertEqual(deadline3, clock.now.advanced(by: keepAliveDuration)) + #expect(deadline3 == clock.now.advanced(by: keepAliveDuration)) clock.advance(to: clock.now.advanced(by: keepAliveDuration)) let failingKeepAliveDidRun = ManagedAtomic(false) // the following keep alive should not cause a crash _ = try? await keepAlive.nextKeepAlive { keepAliveConnection in defer { - XCTAssertFalse(failingKeepAliveDidRun - .compareExchange(expected: false, desired: true, ordering: .relaxed).original) + #expect(failingKeepAliveDidRun + .compareExchange(expected: false, desired: true, ordering: .relaxed).original == false) } - XCTAssertTrue(keepAliveConnection === connectionLease.connection) + #expect(keepAliveConnection === connectionLease.connection) keepAliveConnection.close() throw CancellationError() // any error } // will fail and it's expected - XCTAssertTrue(failingKeepAliveDidRun.load(ordering: .relaxed)) + #expect(failingKeepAliveDidRun.load(ordering: .relaxed) == true) taskGroup.cancelAll() @@ -385,7 +390,8 @@ final class ConnectionPoolTests: XCTestCase { } } - func testKeepAliveWorksRacesAgainstShutdown() async throws { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testKeepAliveWorksRacesAgainstShutdown() async throws { let clock = MockClock() let factory = MockConnectionFactory() let keepAliveDuration = Duration.seconds(30) @@ -420,7 +426,7 @@ final class ConnectionPoolTests: XCTestCase { } let connectionLease = try await connectionLeaseFuture - XCTAssert(connection === connectionLease.connection) + #expect(connection === connectionLease.connection) connectionLease.release() @@ -430,17 +436,17 @@ final class ConnectionPoolTests: XCTestCase { var expectedInstants: Set = [.init(keepAliveDuration), .init(config.idleTimeout)] let deadline1 = await clock.nextTimerScheduled() print(deadline1) - XCTAssertNotNil(expectedInstants.remove(deadline1)) + #expect(expectedInstants.remove(deadline1) != nil) let deadline2 = await clock.nextTimerScheduled() print(deadline2) - XCTAssertNotNil(expectedInstants.remove(deadline2)) - XCTAssert(expectedInstants.isEmpty) + #expect(expectedInstants.remove(deadline2) != nil) + #expect(expectedInstants.isEmpty) clock.advance(to: clock.now.advanced(by: keepAliveDuration)) await keepAlive.nextKeepAlive { keepAliveConnection in defer { print("keep alive 1 has run") } - XCTAssertTrue(keepAliveConnection === connectionLease.connection) + #expect(keepAliveConnection === connectionLease.connection) return true } @@ -453,7 +459,8 @@ final class ConnectionPoolTests: XCTestCase { } } - func testCancelConnectionRequestWorks() async throws { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testCancelConnectionRequestWorks() async throws { let clock = MockClock() let factory = MockConnectionFactory() let keepAliveDuration = Duration.seconds(30) @@ -501,9 +508,9 @@ final class ConnectionPoolTests: XCTestCase { let taskResult = await leaseTask.result switch taskResult { case .success: - XCTFail("Expected task failure") + Issue.record("Expected task failure") case .failure(let failure): - XCTAssertEqual(failure as? ConnectionPoolError, .requestCancelled) + #expect(failure as? ConnectionPoolError == .requestCancelled) } taskGroup.cancelAll() @@ -513,7 +520,8 @@ final class ConnectionPoolTests: XCTestCase { } } - func testLeasingMultipleConnectionsAtOnceWorks() async throws { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testLeasingMultipleConnectionsAtOnceWorks() async throws { let clock = MockClock() let factory = MockConnectionFactory() let keepAliveDuration = Duration.seconds(30) @@ -562,7 +570,7 @@ final class ConnectionPoolTests: XCTestCase { } // Ensure that we got 4 distinct connections - XCTAssertEqual(Set(connectionLeases.lazy.map(\.connection.id)).count, 4) + #expect(Set(connectionLeases.lazy.map(\.connection.id)).count == 4) // release all 4 leased connections for lease in connectionLeases { @@ -577,7 +585,8 @@ final class ConnectionPoolTests: XCTestCase { } } - func testLeasingConnectionAfterShutdownIsInvokedFails() async throws { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testLeasingConnectionAfterShutdownIsInvokedFails() async throws { let clock = MockClock() let factory = MockConnectionFactory() let keepAliveDuration = Duration.seconds(30) @@ -618,10 +627,10 @@ final class ConnectionPoolTests: XCTestCase { do { _ = try await pool.leaseConnection() - XCTFail("Expected a failure") + Issue.record("Expected a failure") } catch { print("failed") - XCTAssertEqual(error as? ConnectionPoolError, .poolShutdown) + #expect(error as? ConnectionPoolError == .poolShutdown) } print("will close connections: \(factory.runningConnections)") @@ -632,7 +641,8 @@ final class ConnectionPoolTests: XCTestCase { } } - func testLeasingConnectionsAfterShutdownIsInvokedFails() async throws { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testLeasingConnectionsAfterShutdownIsInvokedFails() async throws { let clock = MockClock() let factory = MockConnectionFactory() let keepAliveDuration = Duration.seconds(30) @@ -680,9 +690,9 @@ final class ConnectionPoolTests: XCTestCase { for request in requests { do { _ = try await request.future.success - XCTFail("Expected a failure") + Issue.record("Expected a failure") } catch { - XCTAssertEqual(error as? ConnectionPoolError, .poolShutdown) + #expect(error as? ConnectionPoolError == .poolShutdown) } } @@ -693,7 +703,8 @@ final class ConnectionPoolTests: XCTestCase { } } - func testLeasingMultipleStreamsFromOneConnectionWorks() async throws { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testLeasingMultipleStreamsFromOneConnectionWorks() async throws { let clock = MockClock() let factory = MockConnectionFactory() let keepAliveDuration = Duration.seconds(30) @@ -737,7 +748,7 @@ final class ConnectionPoolTests: XCTestCase { } // Ensure that all requests got the same connection - XCTAssertEqual(Set(connectionLeases.lazy.map(\.connection.id)).count, 1) + #expect(Set(connectionLeases.lazy.map(\.connection.id)).count == 1) // release all 10 leased streams for lease in connectionLeases { @@ -758,7 +769,8 @@ final class ConnectionPoolTests: XCTestCase { } } - func testIncreasingAvailableStreamsWorks() async throws { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testIncreasingAvailableStreamsWorks() async throws { let clock = MockClock() let factory = MockConnectionFactory() let keepAliveDuration = Duration.seconds(30) @@ -808,7 +820,7 @@ final class ConnectionPoolTests: XCTestCase { } // Ensure that all requests got the same connection - XCTAssertEqual(Set(connectionLease.lazy.map(\.connection.id)).count, 1) + #expect(Set(connectionLease.lazy.map(\.connection.id)).count == 1) requests = (22..<42).map { ConnectionFuture(id: $0) } pool.leaseConnections(requests) diff --git a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift index 2952bf8b..b4658df8 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift @@ -1,29 +1,29 @@ @testable import _ConnectionPoolModule import _ConnectionPoolTestUtils -import XCTest +import Testing -final class ConnectionRequestTests: XCTestCase { +@Suite struct ConnectionRequestTests { - func testHappyPath() async throws { + @Test func testHappyPath() async throws { let mockConnection = MockConnection(id: 1) let lease = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation, any Error>) in let request = ConnectionRequest(id: 42, continuation: continuation) - XCTAssertEqual(request.id, 42) + #expect(request.id == 42) let lease = ConnectionLease(connection: mockConnection) { _ in } continuation.resume(with: .success(lease)) } - XCTAssert(lease.connection === mockConnection) + #expect(lease.connection === mockConnection) } - func testSadPath() async throws { + @Test func testSadPath() async throws { do { _ = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in continuation.resume(with: .failure(ConnectionPoolError.requestCancelled)) } - XCTFail("This point should not be reached") + Issue.record("This point should not be reached") } catch { - XCTAssertEqual(error as? ConnectionPoolError, .requestCancelled) + #expect(error as? ConnectionPoolError == .requestCancelled) } } } diff --git a/Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift b/Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift index 081e867b..ce620cc3 100644 --- a/Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift +++ b/Tests/ConnectionPoolModuleTests/Max2SequenceTests.swift @@ -1,60 +1,61 @@ @testable import _ConnectionPoolModule -import XCTest +import Testing -final class Max2SequenceTests: XCTestCase { - func testCountAndIsEmpty() async { +@Suite struct Max2SequenceTests { + + @Test func testCountAndIsEmpty() async { var sequence = Max2Sequence() - XCTAssertEqual(sequence.count, 0) - XCTAssertEqual(sequence.isEmpty, true) + #expect(sequence.count == 0) + #expect(sequence.isEmpty == true) sequence.append(1) - XCTAssertEqual(sequence.count, 1) - XCTAssertEqual(sequence.isEmpty, false) + #expect(sequence.count == 1) + #expect(sequence.isEmpty == false) sequence.append(2) - XCTAssertEqual(sequence.count, 2) - XCTAssertEqual(sequence.isEmpty, false) + #expect(sequence.count == 2) + #expect(sequence.isEmpty == false) } - func testOptionalInitializer() { + @Test func testOptionalInitializer() { let emptySequence = Max2Sequence(nil, nil) - XCTAssertEqual(emptySequence.count, 0) - XCTAssertEqual(emptySequence.isEmpty, true) + #expect(emptySequence.count == 0) + #expect(emptySequence.isEmpty == true) var emptySequenceIterator = emptySequence.makeIterator() - XCTAssertNil(emptySequenceIterator.next()) - XCTAssertNil(emptySequenceIterator.next()) - XCTAssertNil(emptySequenceIterator.next()) + #expect(emptySequenceIterator.next() == nil) + #expect(emptySequenceIterator.next() == nil) + #expect(emptySequenceIterator.next() == nil) let oneElemSequence1 = Max2Sequence(1, nil) - XCTAssertEqual(oneElemSequence1.count, 1) - XCTAssertEqual(oneElemSequence1.isEmpty, false) + #expect(oneElemSequence1.count == 1) + #expect(oneElemSequence1.isEmpty == false) var oneElemSequence1Iterator = oneElemSequence1.makeIterator() - XCTAssertEqual(oneElemSequence1Iterator.next(), 1) - XCTAssertNil(oneElemSequence1Iterator.next()) - XCTAssertNil(oneElemSequence1Iterator.next()) + #expect(oneElemSequence1Iterator.next() == 1) + #expect(oneElemSequence1Iterator.next() == nil) + #expect(oneElemSequence1Iterator.next() == nil) let oneElemSequence2 = Max2Sequence(nil, 2) - XCTAssertEqual(oneElemSequence2.count, 1) - XCTAssertEqual(oneElemSequence2.isEmpty, false) + #expect(oneElemSequence2.count == 1) + #expect(oneElemSequence2.isEmpty == false) var oneElemSequence2Iterator = oneElemSequence2.makeIterator() - XCTAssertEqual(oneElemSequence2Iterator.next(), 2) - XCTAssertNil(oneElemSequence2Iterator.next()) - XCTAssertNil(oneElemSequence2Iterator.next()) + #expect(oneElemSequence2Iterator.next() == 2) + #expect(oneElemSequence2Iterator.next() == nil) + #expect(oneElemSequence2Iterator.next() == nil) let twoElemSequence = Max2Sequence(1, 2) - XCTAssertEqual(twoElemSequence.count, 2) - XCTAssertEqual(twoElemSequence.isEmpty, false) + #expect(twoElemSequence.count == 2) + #expect(twoElemSequence.isEmpty == false) var twoElemSequenceIterator = twoElemSequence.makeIterator() - XCTAssertEqual(twoElemSequenceIterator.next(), 1) - XCTAssertEqual(twoElemSequenceIterator.next(), 2) - XCTAssertNil(twoElemSequenceIterator.next()) + #expect(twoElemSequenceIterator.next() == 1) + #expect(twoElemSequenceIterator.next() == 2) + #expect(twoElemSequenceIterator.next() == nil) } func testMap() { let twoElemSequence = Max2Sequence(1, 2).map({ "\($0)" }) - XCTAssertEqual(twoElemSequence.count, 2) - XCTAssertEqual(twoElemSequence.isEmpty, false) + #expect(twoElemSequence.count == 2) + #expect(twoElemSequence.isEmpty == false) var twoElemSequenceIterator = twoElemSequence.makeIterator() - XCTAssertEqual(twoElemSequenceIterator.next(), "1") - XCTAssertEqual(twoElemSequenceIterator.next(), "2") - XCTAssertNil(twoElemSequenceIterator.next()) + #expect(twoElemSequenceIterator.next() == "1") + #expect(twoElemSequenceIterator.next() == "2") + #expect(twoElemSequenceIterator.next() == nil) } } diff --git a/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift index b1b54023..ef6b001a 100644 --- a/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift +++ b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift @@ -1,11 +1,12 @@ import _ConnectionPoolModule import _ConnectionPoolTestUtils -import XCTest +import Testing -@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class NoKeepAliveBehaviorTests: XCTestCase { - func testNoKeepAlive() { + +@Suite struct NoKeepAliveBehaviorTests { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testNoKeepAlive() { let keepAliveBehavior = NoOpKeepAliveBehavior(connectionType: MockConnection.self) - XCTAssertNil(keepAliveBehavior.keepAliveFrequency) + #expect(keepAliveBehavior.keepAliveFrequency == nil) } } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift index b09bfcb4..6bfe0f39 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift @@ -1,22 +1,12 @@ @testable import _ConnectionPoolModule import _ConnectionPoolTestUtils -import XCTest +import Testing -@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class PoolStateMachine_ConnectionGroupTests: XCTestCase { - var idGenerator: ConnectionIDGenerator! +@Suite struct PoolStateMachine_ConnectionGroupTests { + var idGenerator = ConnectionIDGenerator() - override func setUp() { - self.idGenerator = ConnectionIDGenerator() - super.setUp() - } - - override func tearDown() { - self.idGenerator = nil - super.tearDown() - } - - func testRefillConnections() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testRefillConnections() { var connections = TestPoolStateMachine.ConnectionGroup( generator: self.idGenerator, minimumConcurrentConnections: 4, @@ -26,35 +16,36 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { keepAliveReducesAvailableStreams: true ) - XCTAssertTrue(connections.isEmpty) + #expect(connections.isEmpty == true) let requests = connections.refillConnections() - XCTAssertFalse(connections.isEmpty) + #expect(connections.isEmpty == false) - XCTAssertEqual(requests.count, 4) - XCTAssertNil(connections.createNewDemandConnectionIfPossible()) - XCTAssertNil(connections.createNewOverflowConnectionIfPossible()) - XCTAssertEqual(connections.stats, .init(connecting: 4)) - XCTAssertEqual(connections.soonAvailableConnections, 4) + #expect(requests.count == 4) + #expect(connections.createNewDemandConnectionIfPossible() == nil) + #expect(connections.createNewOverflowConnectionIfPossible() == nil) + #expect(connections.stats == .init(connecting: 4)) + #expect(connections.soonAvailableConnections == 4) let requests2 = connections.refillConnections() - XCTAssertTrue(requests2.isEmpty) + #expect(requests2.isEmpty == true) var connected: UInt16 = 0 for request in requests { let newConnection = MockConnection(id: request.connectionID) let (_, context) = connections.newConnectionEstablished(newConnection, maxStreams: 1) - XCTAssertEqual(context.info, .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(context.use, .persisted) + #expect(context.info == .idle(availableStreams: 1, newIdle: true)) + #expect(context.use == .persisted) connected += 1 - XCTAssertEqual(connections.stats, .init(connecting: 4 - connected, idle: connected, availableStreams: connected)) - XCTAssertEqual(connections.soonAvailableConnections, 4 - connected) + #expect(connections.stats == .init(connecting: 4 - connected, idle: connected, availableStreams: connected)) + #expect(connections.soonAvailableConnections == 4 - connected) } let requests3 = connections.refillConnections() - XCTAssertTrue(requests3.isEmpty) + #expect(requests3.isEmpty == true) } - func testMakeConnectionLeaseItAndDropItHappyPath() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testMakeConnectionLeaseItAndDropItHappyPath() { var connections = TestPoolStateMachine.ConnectionGroup( generator: self.idGenerator, minimumConcurrentConnections: 0, @@ -65,70 +56,77 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { ) let requests = connections.refillConnections() - XCTAssertTrue(connections.isEmpty) - XCTAssertTrue(requests.isEmpty) + #expect(connections.isEmpty) + #expect(requests.isEmpty) guard let request = connections.createNewDemandConnectionIfPossible() else { - return XCTFail("Expected to receive a connection request") + Issue.record("Expected to receive a connection request") + return } - XCTAssertEqual(request, .init(connectionID: 0)) - XCTAssertFalse(connections.isEmpty) - XCTAssertEqual(connections.soonAvailableConnections, 1) - XCTAssertEqual(connections.stats, .init(connecting: 1)) + #expect(request == .init(connectionID: 0)) + #expect(!connections.isEmpty) + #expect(connections.soonAvailableConnections == 1) + #expect(connections.stats == .init(connecting: 1)) let newConnection = MockConnection(id: request.connectionID) let (_, establishedContext) = connections.newConnectionEstablished(newConnection, maxStreams: 1) - XCTAssertEqual(establishedContext.info, .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(establishedContext.use, .demand) - XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) - XCTAssertEqual(connections.soonAvailableConnections, 0) + #expect(establishedContext.info == .idle(availableStreams: 1, newIdle: true)) + #expect(establishedContext.use == .demand) + #expect(connections.stats == .init(idle: 1, availableStreams: 1)) + #expect(connections.soonAvailableConnections == 0) guard case .leasedConnection(let leaseResult) = connections.leaseConnectionOrSoonAvailableConnectionCount() else { - return XCTFail("Expected to lease a connection") + Issue.record("Expected to lease a connection") + return } - XCTAssert(newConnection === leaseResult.connection) - XCTAssertEqual(connections.stats, .init(leased: 1, leasedStreams: 1)) + #expect(newConnection === leaseResult.connection) + #expect(connections.stats == .init(leased: 1, leasedStreams: 1)) guard let (index, releasedContext) = connections.releaseConnection(leaseResult.connection.id, streams: 1) else { - return XCTFail("Expected that this connection is still active") + Issue.record("Expected that this connection is still active") + return } - XCTAssertEqual(releasedContext.info, .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(releasedContext.use, .demand) - XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + #expect(releasedContext.info == .idle(availableStreams: 1, newIdle: true)) + #expect(releasedContext.use == .demand) + #expect(connections.stats == .init(idle: 1, availableStreams: 1)) let parkTimers = connections.parkConnection(at: index, hasBecomeIdle: true) - XCTAssertEqual(parkTimers, [ + #expect(parkTimers == [ .init(timerID: 0, connectionID: newConnection.id, usecase: .keepAlive), .init(timerID: 1, connectionID: newConnection.id, usecase: .idleTimeout), ]) guard let keepAliveAction = connections.keepAliveIfIdle(newConnection.id) else { - return XCTFail("Expected to get a connection for ping pong") + Issue.record("Expected to get a connection for ping pong") + return } - XCTAssert(newConnection === keepAliveAction.connection) - XCTAssertEqual(connections.stats, .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) + #expect(newConnection === keepAliveAction.connection) + #expect(connections.stats == .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) guard let (_, pingPongContext) = connections.keepAliveSucceeded(newConnection.id) else { - return XCTFail("Expected to get an AvailableContext") + Issue.record("Expected to get an AvailableContext") + return } - XCTAssertEqual(pingPongContext.info, .idle(availableStreams: 1, newIdle: false)) - XCTAssertEqual(releasedContext.use, .demand) - XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + #expect(pingPongContext.info == .idle(availableStreams: 1, newIdle: false)) + #expect(releasedContext.use == .demand) + #expect(connections.stats == .init(idle: 1, availableStreams: 1)) guard let closeAction = connections.closeConnectionIfIdle(newConnection.id) else { - return XCTFail("Expected to get a connection for ping pong") + Issue.record("Expected to get a connection for ping pong") + return } - XCTAssertEqual(closeAction.timersToCancel, []) - XCTAssert(closeAction.connection === newConnection) - XCTAssertEqual(connections.stats, .init(closing: 1, availableStreams: 0)) + #expect(closeAction.timersToCancel == []) + #expect(closeAction.connection === newConnection) + #expect(connections.stats == .init(closing: 1, availableStreams: 0)) let closeContext = connections.connectionClosed(newConnection.id) - XCTAssertEqual(closeContext.connectionsStarting, 0) - XCTAssertTrue(connections.isEmpty) - XCTAssertEqual(connections.stats, .init()) + #expect(closeContext.connectionsStarting == 0) + #expect(connections.isEmpty) + #expect(connections.stats == .init()) } - func testBackoffDoneCreatesANewConnectionToReachMinimumConnectionsEvenThoughRetryIsSetToFalse() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testBackoffDoneCreatesANewConnectionToReachMinimumConnectionsEvenThoughRetryIsSetToFalse() { var connections = TestPoolStateMachine.ConnectionGroup( generator: self.idGenerator, minimumConcurrentConnections: 1, @@ -139,26 +137,30 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { ) let requests = connections.refillConnections() - XCTAssertEqual(connections.stats, .init(connecting: 1)) - XCTAssertEqual(connections.soonAvailableConnections, 1) - XCTAssertFalse(connections.isEmpty) - XCTAssertEqual(requests.count, 1) - - guard let request = requests.first else { return XCTFail("Expected to receive a connection request") } - XCTAssertEqual(request, .init(connectionID: 0)) + #expect(connections.stats == .init(connecting: 1)) + #expect(connections.soonAvailableConnections == 1) + #expect(!connections.isEmpty) + #expect(requests.count == 1) + + guard let request = requests.first else { + Issue.record("Expected to receive a connection request") + return + } + #expect(request == .init(connectionID: 0)) let backoffTimer = connections.backoffNextConnectionAttempt(request.connectionID) - XCTAssertEqual(connections.stats, .init(backingOff: 1)) + #expect(connections.stats == .init(backingOff: 1)) let backoffTimerCancellationToken = MockTimerCancellationToken(backoffTimer) - XCTAssertNil(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken)) + #expect(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken) == nil) let backoffDoneAction = connections.backoffDone(request.connectionID, retry: false) - XCTAssertEqual(backoffDoneAction, .createConnection(.init(connectionID: 0), backoffTimerCancellationToken)) + #expect(backoffDoneAction == .createConnection(.init(connectionID: 0), backoffTimerCancellationToken)) - XCTAssertEqual(connections.stats, .init(connecting: 1)) + #expect(connections.stats == .init(connecting: 1)) } - func testBackoffDoneCancelsIdleTimerIfAPersistedConnectionIsNotRetried() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testBackoffDoneCancelsIdleTimerIfAPersistedConnectionIsNotRetried() { var connections = TestPoolStateMachine.ConnectionGroup( generator: self.idGenerator, minimumConcurrentConnections: 2, @@ -169,57 +171,60 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { ) let requests = connections.refillConnections() - XCTAssertEqual(connections.stats, .init(connecting: 2)) - XCTAssertEqual(connections.soonAvailableConnections, 2) - XCTAssertFalse(connections.isEmpty) - XCTAssertEqual(requests.count, 2) + #expect(connections.stats == .init(connecting: 2)) + #expect(connections.soonAvailableConnections == 2) + #expect(!connections.isEmpty) + #expect(requests.count == 2) var requestIterator = requests.makeIterator() guard let firstRequest = requestIterator.next(), let secondRequest = requestIterator.next() else { - return XCTFail("Expected to get two requests") + Issue.record("Expected to get two requests") + return } guard let thirdRequest = connections.createNewDemandConnectionIfPossible() else { - return XCTFail("Expected to get another request") + Issue.record("Expected to get another request") + return } - XCTAssertEqual(connections.stats, .init(connecting: 3)) + #expect(connections.stats == .init(connecting: 3)) let newSecondConnection = MockConnection(id: secondRequest.connectionID) let (_, establishedSecondConnectionContext) = connections.newConnectionEstablished(newSecondConnection, maxStreams: 1) - XCTAssertEqual(establishedSecondConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(establishedSecondConnectionContext.use, .persisted) - XCTAssertEqual(connections.stats, .init(connecting: 2, idle: 1, availableStreams: 1)) - XCTAssertEqual(connections.soonAvailableConnections, 2) + #expect(establishedSecondConnectionContext.info == .idle(availableStreams: 1, newIdle: true)) + #expect(establishedSecondConnectionContext.use == .persisted) + #expect(connections.stats == .init(connecting: 2, idle: 1, availableStreams: 1)) + #expect(connections.soonAvailableConnections == 2) let newThirdConnection = MockConnection(id: thirdRequest.connectionID) let (thirdConnectionIndex, establishedThirdConnectionContext) = connections.newConnectionEstablished(newThirdConnection, maxStreams: 1) - XCTAssertEqual(establishedThirdConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(establishedThirdConnectionContext.use, .demand) - XCTAssertEqual(connections.stats, .init(connecting: 1, idle: 2, availableStreams: 2)) - XCTAssertEqual(connections.soonAvailableConnections, 1) + #expect(establishedThirdConnectionContext.info == .idle(availableStreams: 1, newIdle: true)) + #expect(establishedThirdConnectionContext.use == .demand) + #expect(connections.stats == .init(connecting: 1, idle: 2, availableStreams: 2)) + #expect(connections.soonAvailableConnections == 1) let thirdConnKeepTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: thirdRequest.connectionID, usecase: .keepAlive) let thirdConnIdleTimer = TestPoolStateMachine.ConnectionTimer(timerID: 1, connectionID: thirdRequest.connectionID, usecase: .idleTimeout) let thirdConnIdleTimerCancellationToken = MockTimerCancellationToken(thirdConnIdleTimer) - XCTAssertEqual(connections.parkConnection(at: thirdConnectionIndex, hasBecomeIdle: true), [thirdConnKeepTimer, thirdConnIdleTimer]) + #expect(connections.parkConnection(at: thirdConnectionIndex, hasBecomeIdle: true) == [thirdConnKeepTimer, thirdConnIdleTimer]) - XCTAssertNil(connections.timerScheduled(thirdConnKeepTimer, cancelContinuation: .init(thirdConnKeepTimer))) - XCTAssertNil(connections.timerScheduled(thirdConnIdleTimer, cancelContinuation: thirdConnIdleTimerCancellationToken)) + #expect(connections.timerScheduled(thirdConnKeepTimer, cancelContinuation: .init(thirdConnKeepTimer)) == nil) + #expect(connections.timerScheduled(thirdConnIdleTimer, cancelContinuation: thirdConnIdleTimerCancellationToken) == nil) let backoffTimer = connections.backoffNextConnectionAttempt(firstRequest.connectionID) - XCTAssertEqual(connections.stats, .init(backingOff: 1, idle: 2, availableStreams: 2)) + #expect(connections.stats == .init(backingOff: 1, idle: 2, availableStreams: 2)) let backoffTimerCancellationToken = MockTimerCancellationToken(backoffTimer) - XCTAssertNil(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken)) - XCTAssertEqual(connections.stats, .init(backingOff: 1, idle: 2, availableStreams: 2)) + #expect(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken) == nil) + #expect(connections.stats == .init(backingOff: 1, idle: 2, availableStreams: 2)) // connection three should be moved to connection one and for this reason become permanent - XCTAssertEqual(connections.backoffDone(firstRequest.connectionID, retry: false), .cancelTimers([backoffTimerCancellationToken, thirdConnIdleTimerCancellationToken])) - XCTAssertEqual(connections.stats, .init(idle: 2, availableStreams: 2)) + #expect(connections.backoffDone(firstRequest.connectionID, retry: false) == .cancelTimers([backoffTimerCancellationToken, thirdConnIdleTimerCancellationToken])) + #expect(connections.stats == .init(idle: 2, availableStreams: 2)) - XCTAssertNil(connections.closeConnectionIfIdle(newThirdConnection.id)) + #expect(connections.closeConnectionIfIdle(newThirdConnection.id) == nil) } - func testBackoffDoneReturnsNilIfOverflowConnection() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testBackoffDoneReturnsNilIfOverflowConnection() { var connections = TestPoolStateMachine.ConnectionGroup( generator: self.idGenerator, minimumConcurrentConnections: 0, @@ -230,33 +235,36 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { ) guard let firstRequest = connections.createNewDemandConnectionIfPossible() else { - return XCTFail("Expected to get two requests") + Issue.record("Expected to get two requests") + return } guard let secondRequest = connections.createNewDemandConnectionIfPossible() else { - return XCTFail("Expected to get another request") + Issue.record("Expected to get another request") + return } - XCTAssertEqual(connections.stats, .init(connecting: 2)) + #expect(connections.stats == .init(connecting: 2)) let newFirstConnection = MockConnection(id: firstRequest.connectionID) let (_, establishedFirstConnectionContext) = connections.newConnectionEstablished(newFirstConnection, maxStreams: 1) - XCTAssertEqual(establishedFirstConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(establishedFirstConnectionContext.use, .demand) - XCTAssertEqual(connections.stats, .init(connecting: 1, idle: 1, availableStreams: 1)) - XCTAssertEqual(connections.soonAvailableConnections, 1) + #expect(establishedFirstConnectionContext.info == .idle(availableStreams: 1, newIdle: true)) + #expect(establishedFirstConnectionContext.use == .demand) + #expect(connections.stats == .init(connecting: 1, idle: 1, availableStreams: 1)) + #expect(connections.soonAvailableConnections == 1) let backoffTimer = connections.backoffNextConnectionAttempt(secondRequest.connectionID) let backoffTimerCancellationToken = MockTimerCancellationToken(backoffTimer) - XCTAssertEqual(connections.stats, .init(backingOff: 1, idle: 1, availableStreams: 1)) - XCTAssertNil(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken)) + #expect(connections.stats == .init(backingOff: 1, idle: 1, availableStreams: 1)) + #expect(connections.timerScheduled(backoffTimer, cancelContinuation: backoffTimerCancellationToken) == nil) - XCTAssertEqual(connections.backoffDone(secondRequest.connectionID, retry: false), .cancelTimers([backoffTimerCancellationToken])) - XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + #expect(connections.backoffDone(secondRequest.connectionID, retry: false) == .cancelTimers([backoffTimerCancellationToken])) + #expect(connections.stats == .init(idle: 1, availableStreams: 1)) - XCTAssertNotNil(connections.closeConnectionIfIdle(newFirstConnection.id)) + #expect(connections.closeConnectionIfIdle(newFirstConnection.id) != nil) } - func testPingPong() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testPingPong() { var connections = TestPoolStateMachine.ConnectionGroup( generator: self.idGenerator, minimumConcurrentConnections: 1, @@ -267,35 +275,40 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { ) let requests = connections.refillConnections() - XCTAssertFalse(connections.isEmpty) - XCTAssertEqual(connections.stats, .init(connecting: 1)) + #expect(!connections.isEmpty) + #expect(connections.stats == .init(connecting: 1)) - XCTAssertEqual(requests.count, 1) - guard let firstRequest = requests.first else { return XCTFail("Expected to have a request here") } + #expect(requests.count == 1) + guard let firstRequest = requests.first else { + Issue.record("Expected to have a request here") + return + } let newConnection = MockConnection(id: firstRequest.connectionID) let (connectionIndex, establishedConnectionContext) = connections.newConnectionEstablished(newConnection, maxStreams: 1) - XCTAssertEqual(establishedConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(establishedConnectionContext.use, .persisted) - XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + #expect(establishedConnectionContext.info == .idle(availableStreams: 1, newIdle: true)) + #expect(establishedConnectionContext.use == .persisted) + #expect(connections.stats == .init(idle: 1, availableStreams: 1)) let timers = connections.parkConnection(at: connectionIndex, hasBecomeIdle: true) let keepAliveTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: firstRequest.connectionID, usecase: .keepAlive) let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) - XCTAssertEqual(timers, [keepAliveTimer]) - XCTAssertNil(connections.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + #expect(timers == [keepAliveTimer]) + #expect(connections.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken) == nil) let keepAliveAction = connections.keepAliveIfIdle(newConnection.id) - XCTAssertEqual(keepAliveAction, .init(connection: newConnection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken)) - XCTAssertEqual(connections.stats, .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) + #expect(keepAliveAction == .init(connection: newConnection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken)) + #expect(connections.stats == .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) guard let (_, afterPingIdleContext) = connections.keepAliveSucceeded(newConnection.id) else { - return XCTFail("Expected to receive an AvailableContext") + Issue.record("Expected to receive an AvailableContext") + return } - XCTAssertEqual(afterPingIdleContext.info, .idle(availableStreams: 1, newIdle: false)) - XCTAssertEqual(afterPingIdleContext.use, .persisted) - XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + #expect(afterPingIdleContext.info == .idle(availableStreams: 1, newIdle: false)) + #expect(afterPingIdleContext.use == .persisted) + #expect(connections.stats == .init(idle: 1, availableStreams: 1)) } - func testKeepAliveShouldNotIndicateCloseConnectionAfterClosed() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testKeepAliveShouldNotIndicateCloseConnectionAfterClosed() { var connections = TestPoolStateMachine.ConnectionGroup( generator: self.idGenerator, minimumConcurrentConnections: 0, @@ -305,24 +318,28 @@ final class PoolStateMachine_ConnectionGroupTests: XCTestCase { keepAliveReducesAvailableStreams: true ) - guard let firstRequest = connections.createNewDemandConnectionIfPossible() else { return XCTFail("Expected to have a request here") } + guard let firstRequest = connections.createNewDemandConnectionIfPossible() else { + Issue.record("Expected to have a request here") + return + } let newConnection = MockConnection(id: firstRequest.connectionID) let (connectionIndex, establishedConnectionContext) = connections.newConnectionEstablished(newConnection, maxStreams: 1) - XCTAssertEqual(establishedConnectionContext.info, .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(connections.stats, .init(idle: 1, availableStreams: 1)) + #expect(establishedConnectionContext.info == .idle(availableStreams: 1, newIdle: true)) + #expect(connections.stats == .init(idle: 1, availableStreams: 1)) _ = connections.parkConnection(at: connectionIndex, hasBecomeIdle: true) let keepAliveTimer = TestPoolStateMachine.ConnectionTimer(timerID: 0, connectionID: firstRequest.connectionID, usecase: .keepAlive) let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) - XCTAssertNil(connections.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + #expect(connections.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken) == nil) let keepAliveAction = connections.keepAliveIfIdle(newConnection.id) - XCTAssertEqual(keepAliveAction, .init(connection: newConnection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken)) - XCTAssertEqual(connections.stats, .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) + #expect(keepAliveAction == .init(connection: newConnection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken)) + #expect(connections.stats == .init(idle: 1, runningKeepAlive: 1, availableStreams: 0)) _ = connections.closeConnectionIfIdle(newConnection.id) guard connections.keepAliveFailed(newConnection.id) == nil else { - return XCTFail("Expected keepAliveFailed not to cause close again") + Issue.record("Expected keepAliveFailed not to cause close again") + return } - XCTAssertEqual(connections.stats, .init(closing: 1)) + #expect(connections.stats == .init(closing: 1)) } } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift index 7dd2b726..2d81cf38 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift @@ -1,36 +1,36 @@ @testable import _ConnectionPoolModule import _ConnectionPoolTestUtils -import XCTest +import Testing -@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class PoolStateMachine_ConnectionStateTests: XCTestCase { +@Suite struct PoolStateMachine_ConnectionStateTests { typealias TestConnectionState = TestPoolStateMachine.ConnectionState - func testStartupLeaseReleaseParkLease() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testStartupLeaseReleaseParkLease() { let connectionID = 1 var state = TestConnectionState(id: connectionID) - XCTAssertEqual(state.id, connectionID) - XCTAssertEqual(state.isIdle, false) - XCTAssertEqual(state.isAvailable, false) - XCTAssertEqual(state.isConnected, false) - XCTAssertEqual(state.isLeased, false) + #expect(state.id == connectionID) + #expect(!state.isIdle) + #expect(!state.isAvailable) + #expect(!state.isConnected) + #expect(!state.isLeased) let connection = MockConnection(id: connectionID) - XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual(state.isIdle, true) - XCTAssertEqual(state.isAvailable, true) - XCTAssertEqual(state.isConnected, true) - XCTAssertEqual(state.isLeased, false) - XCTAssertEqual(state.lease(streams: 1), .init(connection: connection, timersToCancel: .init(), wasIdle: true)) - - XCTAssertEqual(state.isIdle, false) - XCTAssertEqual(state.isAvailable, false) - XCTAssertEqual(state.isConnected, true) - XCTAssertEqual(state.isLeased, true) - - XCTAssertEqual(state.release(streams: 1), .idle(availableStreams: 1, newIdle: true)) + #expect(state.connected(connection, maxStreams: 1) == .idle(availableStreams: 1, newIdle: true)) + #expect(state.isIdle) + #expect(state.isAvailable) + #expect(state.isConnected) + #expect(state.isLeased == false) + #expect(state.lease(streams: 1) == .init(connection: connection, timersToCancel: .init(), wasIdle: true)) + + #expect(!state.isIdle) + #expect(!state.isAvailable) + #expect(state.isConnected) + #expect(state.isLeased) + + #expect(state.release(streams: 1) == .idle(availableStreams: 1, newIdle: true)) let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) - XCTAssert( + #expect( parkResult.elementsEqual([ .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive), .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout) @@ -38,31 +38,33 @@ final class PoolStateMachine_ConnectionStateTests: XCTestCase { ) guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { - return XCTFail("Expected to get two timers") + Issue.record("Expected to get two timers") + return } let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) let idleTimerCancellationToken = MockTimerCancellationToken(idleTimer) - XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) - XCTAssertNil(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken)) + #expect(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken) == nil) + #expect(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken) == nil) let expectLeaseAction = TestConnectionState.LeaseAction( connection: connection, timersToCancel: [idleTimerCancellationToken, keepAliveTimerCancellationToken], wasIdle: true ) - XCTAssertEqual(state.lease(streams: 1), expectLeaseAction) + #expect(state.lease(streams: 1) == expectLeaseAction) } - func testStartupParkLeaseBeforeTimersRegistered() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testStartupParkLeaseBeforeTimersRegistered() { let connectionID = 1 var state = TestConnectionState(id: connectionID) let connection = MockConnection(id: connectionID) - XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + #expect(state.connected(connection, maxStreams: 1) == .idle(availableStreams: 1, newIdle: true)) let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) - XCTAssertEqual( - parkResult, + #expect( + parkResult == [ .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive), .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout) @@ -70,24 +72,26 @@ final class PoolStateMachine_ConnectionStateTests: XCTestCase { ) guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { - return XCTFail("Expected to get two timers") + Issue.record("Expected to get two timers") + return } let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) let idleTimerCancellationToken = MockTimerCancellationToken(idleTimer) - XCTAssertEqual(state.lease(streams: 1), .init(connection: connection, timersToCancel: .init(), wasIdle: true)) + #expect(state.lease(streams: 1) == .init(connection: connection, timersToCancel: .init(), wasIdle: true)) - XCTAssertEqual(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken), keepAliveTimerCancellationToken) - XCTAssertEqual(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken), idleTimerCancellationToken) + #expect(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken) == keepAliveTimerCancellationToken) + #expect(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken) == idleTimerCancellationToken) } - func testStartupParkLeasePark() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testStartupParkLeasePark() { let connectionID = 1 var state = TestConnectionState(id: connectionID) let connection = MockConnection(id: connectionID) - XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + #expect(state.connected(connection, maxStreams: 1) == .idle(availableStreams: 1, newIdle: true)) let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) - XCTAssert( + #expect( parkResult.elementsEqual([ .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive), .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout) @@ -95,171 +99,186 @@ final class PoolStateMachine_ConnectionStateTests: XCTestCase { ) guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { - return XCTFail("Expected to get two timers") + Issue.record("Expected to get two timers") + return } let initialKeepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) let initialIdleTimerCancellationToken = MockTimerCancellationToken(idleTimer) - XCTAssertEqual(state.lease(streams: 1), .init(connection: connection, timersToCancel: .init(), wasIdle: true)) + #expect(state.lease(streams: 1) == .init(connection: connection, timersToCancel: .init(), wasIdle: true)) - XCTAssertEqual(state.release(streams: 1), .idle(availableStreams: 1, newIdle: true)) - XCTAssertEqual( - state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true), + #expect(state.release(streams: 1) == .idle(availableStreams: 1, newIdle: true)) + #expect( + state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) == [ .init(timerID: 2, connectionID: connectionID, usecase: .keepAlive), .init(timerID: 3, connectionID: connectionID, usecase: .idleTimeout) ] ) - XCTAssertEqual(state.timerScheduled(keepAliveTimer, cancelContinuation: initialKeepAliveTimerCancellationToken), initialKeepAliveTimerCancellationToken) - XCTAssertEqual(state.timerScheduled(idleTimer, cancelContinuation: initialIdleTimerCancellationToken), initialIdleTimerCancellationToken) + #expect(state.timerScheduled(keepAliveTimer, cancelContinuation: initialKeepAliveTimerCancellationToken) == initialKeepAliveTimerCancellationToken) + #expect(state.timerScheduled(idleTimer, cancelContinuation: initialIdleTimerCancellationToken) == initialIdleTimerCancellationToken) } - func testStartupFailed() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testStartupFailed() { let connectionID = 1 var state = TestConnectionState(id: connectionID) let firstBackoffTimer = state.failedToConnect() let firstBackoffTimerCancellationToken = MockTimerCancellationToken(firstBackoffTimer) - XCTAssertNil(state.timerScheduled(firstBackoffTimer, cancelContinuation: firstBackoffTimerCancellationToken)) - XCTAssertEqual(state.retryConnect(), firstBackoffTimerCancellationToken) + #expect(state.timerScheduled(firstBackoffTimer, cancelContinuation: firstBackoffTimerCancellationToken) == nil) + #expect(state.retryConnect() == firstBackoffTimerCancellationToken) let secondBackoffTimer = state.failedToConnect() let secondBackoffTimerCancellationToken = MockTimerCancellationToken(secondBackoffTimer) - XCTAssertNil(state.retryConnect()) - XCTAssertEqual( - state.timerScheduled(secondBackoffTimer, cancelContinuation: secondBackoffTimerCancellationToken), + #expect(state.retryConnect() == nil) + #expect( + state.timerScheduled(secondBackoffTimer, cancelContinuation: secondBackoffTimerCancellationToken) == secondBackoffTimerCancellationToken ) let thirdBackoffTimer = state.failedToConnect() let thirdBackoffTimerCancellationToken = MockTimerCancellationToken(thirdBackoffTimer) - XCTAssertNil(state.retryConnect()) + #expect(state.retryConnect() == nil) let forthBackoffTimer = state.failedToConnect() let forthBackoffTimerCancellationToken = MockTimerCancellationToken(forthBackoffTimer) - XCTAssertEqual( - state.timerScheduled(thirdBackoffTimer, cancelContinuation: thirdBackoffTimerCancellationToken), + #expect( + state.timerScheduled(thirdBackoffTimer, cancelContinuation: thirdBackoffTimerCancellationToken) == thirdBackoffTimerCancellationToken ) - XCTAssertNil( - state.timerScheduled(forthBackoffTimer, cancelContinuation: forthBackoffTimerCancellationToken) + #expect( + state.timerScheduled(forthBackoffTimer, cancelContinuation: forthBackoffTimerCancellationToken) == nil ) - XCTAssertEqual(state.retryConnect(), forthBackoffTimerCancellationToken) + #expect(state.retryConnect() == forthBackoffTimerCancellationToken) let connection = MockConnection(id: connectionID) - XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + #expect(state.connected(connection, maxStreams: 1) == .idle(availableStreams: 1, newIdle: true)) } - func testLeaseMultipleStreams() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testLeaseMultipleStreams() { let connectionID = 1 var state = TestConnectionState(id: connectionID) let connection = MockConnection(id: connectionID) - XCTAssertEqual(state.connected(connection, maxStreams: 100), .idle(availableStreams: 100, newIdle: true)) + #expect(state.connected(connection, maxStreams: 100) == .idle(availableStreams: 100, newIdle: true)) let timers = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: false) - guard let keepAliveTimer = timers.first else { return XCTFail("Expected to get a keepAliveTimer") } + guard let keepAliveTimer = timers.first else { + Issue.record("Expected to get a keepAliveTimer") + return + } let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) - XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + #expect(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken) == nil) - XCTAssertEqual( - state.lease(streams: 30), + #expect( + state.lease(streams: 30) == TestConnectionState.LeaseAction(connection: connection, timersToCancel: [keepAliveTimerCancellationToken], wasIdle: true) ) - XCTAssertEqual(state.release(streams: 10), .leased(availableStreams: 80)) + #expect(state.release(streams: 10) == .leased(availableStreams: 80)) - XCTAssertEqual( - state.lease(streams: 40), + #expect( + state.lease(streams: 40) == TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: false) ) - XCTAssertEqual( - state.lease(streams: 40), + #expect( + state.lease(streams: 40) == TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: false) ) - XCTAssertEqual(state.release(streams: 1), .leased(availableStreams: 1)) - XCTAssertEqual(state.release(streams: 98), .leased(availableStreams: 99)) - XCTAssertEqual(state.release(streams: 1), .idle(availableStreams: 100, newIdle: true)) + #expect(state.release(streams: 1) == .leased(availableStreams: 1)) + #expect(state.release(streams: 98) == .leased(availableStreams: 99)) + #expect(state.release(streams: 1) == .idle(availableStreams: 100, newIdle: true)) } - func testRunningKeepAliveReducesAvailableStreams() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testRunningKeepAliveReducesAvailableStreams() { let connectionID = 1 var state = TestConnectionState(id: connectionID) let connection = MockConnection(id: connectionID) - XCTAssertEqual(state.connected(connection, maxStreams: 100), .idle(availableStreams: 100, newIdle: true)) + #expect(state.connected(connection, maxStreams: 100) == .idle(availableStreams: 100, newIdle: true)) let timers = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: false) - guard let keepAliveTimer = timers.first else { return XCTFail("Expected to get a keepAliveTimer") } + guard let keepAliveTimer = timers.first else { + Issue.record("Expected to get a keepAliveTimer") + return + } let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) - XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + #expect(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken) == nil) - XCTAssertEqual( - state.runKeepAliveIfIdle(reducesAvailableStreams: true), + #expect( + state.runKeepAliveIfIdle(reducesAvailableStreams: true) == .init(connection: connection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken) ) - XCTAssertEqual( - state.lease(streams: 30), + #expect( + state.lease(streams: 30) == TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: true) ) - XCTAssertEqual(state.release(streams: 10), .leased(availableStreams: 79)) - XCTAssertEqual(state.isAvailable, true) - XCTAssertEqual( - state.lease(streams: 79), + #expect(state.release(streams: 10) == .leased(availableStreams: 79)) + #expect(state.isAvailable) + #expect( + state.lease(streams: 79) == TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: false) ) - XCTAssertEqual(state.isAvailable, false) - XCTAssertEqual(state.keepAliveSucceeded(), .leased(availableStreams: 1)) - XCTAssertEqual(state.isAvailable, true) + #expect(!state.isAvailable) + #expect(state.keepAliveSucceeded() == .leased(availableStreams: 1)) + #expect(state.isAvailable) } - func testRunningKeepAliveDoesNotReduceAvailableStreams() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testRunningKeepAliveDoesNotReduceAvailableStreams() { let connectionID = 1 var state = TestConnectionState(id: connectionID) let connection = MockConnection(id: connectionID) - XCTAssertEqual(state.connected(connection, maxStreams: 100), .idle(availableStreams: 100, newIdle: true)) + #expect(state.connected(connection, maxStreams: 100) == .idle(availableStreams: 100, newIdle: true)) let timers = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: false) - guard let keepAliveTimer = timers.first else { return XCTFail("Expected to get a keepAliveTimer") } + guard let keepAliveTimer = timers.first else { + Issue.record("Expected to get a keepAliveTimer") + return + } let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) - XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) + #expect(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken) == nil) - XCTAssertEqual( - state.runKeepAliveIfIdle(reducesAvailableStreams: false), + #expect( + state.runKeepAliveIfIdle(reducesAvailableStreams: false) == .init(connection: connection, keepAliveTimerCancellationContinuation: keepAliveTimerCancellationToken) ) - XCTAssertEqual( - state.lease(streams: 30), + #expect( + state.lease(streams: 30) == TestConnectionState.LeaseAction(connection: connection, timersToCancel: [], wasIdle: true) ) - XCTAssertEqual(state.release(streams: 10), .leased(availableStreams: 80)) - XCTAssertEqual(state.keepAliveSucceeded(), .leased(availableStreams: 80)) + #expect(state.release(streams: 10) == .leased(availableStreams: 80)) + #expect(state.keepAliveSucceeded() == .leased(availableStreams: 80)) } - func testRunKeepAliveRacesAgainstIdleClose() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testRunKeepAliveRacesAgainstIdleClose() { let connectionID = 1 var state = TestConnectionState(id: connectionID) let connection = MockConnection(id: connectionID) - XCTAssertEqual(state.connected(connection, maxStreams: 1), .idle(availableStreams: 1, newIdle: true)) + #expect(state.connected(connection, maxStreams: 1) == .idle(availableStreams: 1, newIdle: true)) let parkResult = state.parkConnection(scheduleKeepAliveTimer: true, scheduleIdleTimeoutTimer: true) guard let keepAliveTimer = parkResult.first, let idleTimer = parkResult.second else { - return XCTFail("Expected to get two timers") + Issue.record("Expected to get two timers") + return } - XCTAssertEqual(keepAliveTimer, .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive)) - XCTAssertEqual(idleTimer, .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout)) + #expect(keepAliveTimer == .init(timerID: 0, connectionID: connectionID, usecase: .keepAlive)) + #expect(idleTimer == .init(timerID: 1, connectionID: connectionID, usecase: .idleTimeout)) let keepAliveTimerCancellationToken = MockTimerCancellationToken(keepAliveTimer) let idleTimerCancellationToken = MockTimerCancellationToken(idleTimer) - XCTAssertNil(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken)) - XCTAssertNil(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken)) - - XCTAssertEqual(state.closeIfIdle(), .init(connection: connection, previousConnectionState: .idle, cancelTimers: [keepAliveTimerCancellationToken, idleTimerCancellationToken], usedStreams: 0, maxStreams: 1, runningKeepAlive: false)) - XCTAssertEqual(state.runKeepAliveIfIdle(reducesAvailableStreams: true), .none) + #expect(state.timerScheduled(keepAliveTimer, cancelContinuation: keepAliveTimerCancellationToken) == nil) + #expect(state.timerScheduled(idleTimer, cancelContinuation: idleTimerCancellationToken) == nil) + #expect(state.closeIfIdle() == .init(connection: connection, previousConnectionState: .idle, cancelTimers: [keepAliveTimerCancellationToken, idleTimerCancellationToken], usedStreams: 0, maxStreams: 1, runningKeepAlive: false)) + #expect(state.runKeepAliveIfIdle(reducesAvailableStreams: true) == .none) } } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift index ddd6a71e..458c6b3f 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift @@ -1,29 +1,30 @@ @testable import _ConnectionPoolModule import _ConnectionPoolTestUtils -import XCTest +import Testing -@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class PoolStateMachine_RequestQueueTests: XCTestCase { +@Suite struct PoolStateMachine_RequestQueueTests { typealias TestQueue = TestPoolStateMachine.RequestQueue - func testHappyPath() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testHappyPath() { var queue = TestQueue() - XCTAssert(queue.isEmpty) + #expect(queue.isEmpty) let request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - XCTAssertEqual(queue.count, 1) - XCTAssertFalse(queue.isEmpty) + #expect(queue.count == 1) + #expect(!queue.isEmpty) let popResult = queue.pop(max: 3) - XCTAssert(popResult.elementsEqual([request1])) - XCTAssert(queue.isEmpty) - XCTAssertEqual(queue.count, 0) + #expect(popResult.elementsEqual([request1])) + #expect(queue.isEmpty) + #expect(queue.count == 0) } - func testEnqueueAndPopMultipleRequests() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testEnqueueAndPopMultipleRequests() { var queue = TestQueue() - XCTAssert(queue.isEmpty) + #expect(queue.isEmpty) var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) @@ -33,21 +34,22 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { queue.queue(request3) do { - XCTAssertEqual(queue.count, 3) - XCTAssertFalse(queue.isEmpty) + #expect(queue.count == 3) + #expect(!queue.isEmpty) let popResult = queue.pop(max: 3) - XCTAssert(popResult.elementsEqual([request1, request2, request3])) - XCTAssert(queue.isEmpty) - XCTAssertEqual(queue.count, 0) + #expect(popResult.elementsEqual([request1, request2, request3])) + #expect(queue.isEmpty) + #expect(queue.count == 0) } - XCTAssert(isKnownUniquelyReferenced(&request1)) - XCTAssert(isKnownUniquelyReferenced(&request2)) - XCTAssert(isKnownUniquelyReferenced(&request3)) + #expect(isKnownUniquelyReferenced(&request1)) + #expect(isKnownUniquelyReferenced(&request2)) + #expect(isKnownUniquelyReferenced(&request3)) } - func testEnqueueAndPopOnlyOne() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testEnqueueAndPopOnlyOne() { var queue = TestQueue() - XCTAssert(queue.isEmpty) + #expect(queue.isEmpty) var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) @@ -57,24 +59,25 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { queue.queue(request3) do { - XCTAssertEqual(queue.count, 3) - XCTAssertFalse(queue.isEmpty) + #expect(queue.count == 3) + #expect(!queue.isEmpty) let popResult = queue.pop(max: 1) - XCTAssert(popResult.elementsEqual([request1])) - XCTAssertFalse(queue.isEmpty) - XCTAssertEqual(queue.count, 2) + #expect(popResult.elementsEqual([request1])) + #expect(!queue.isEmpty) + #expect(queue.count == 2) let removeAllResult = queue.removeAll() - XCTAssert(Set(removeAllResult) == [request2, request3]) + #expect(Set(removeAllResult) == [request2, request3]) } - XCTAssert(isKnownUniquelyReferenced(&request1)) - XCTAssert(isKnownUniquelyReferenced(&request2)) - XCTAssert(isKnownUniquelyReferenced(&request3)) + #expect(isKnownUniquelyReferenced(&request1)) + #expect(isKnownUniquelyReferenced(&request2)) + #expect(isKnownUniquelyReferenced(&request3)) } - func testCancellation() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testCancellation() { var queue = TestQueue() - XCTAssert(queue.isEmpty) + #expect(queue.isEmpty) var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) @@ -84,34 +87,35 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { queue.queue(request3) do { - XCTAssertEqual(queue.count, 3) + #expect(queue.count == 3) let returnedRequest2 = queue.remove(request2.id) - XCTAssert(returnedRequest2 === request2) - XCTAssertEqual(queue.count, 2) - XCTAssertFalse(queue.isEmpty) + #expect(returnedRequest2 === request2) + #expect(queue.count == 2) + #expect(!queue.isEmpty) } // still retained by the deque inside the queue - XCTAssertEqual(queue.requests.count, 2) - XCTAssertEqual(queue.queue.count, 3) + #expect(queue.requests.count == 2) + #expect(queue.queue.count == 3) do { - XCTAssertEqual(queue.count, 2) - XCTAssertFalse(queue.isEmpty) + #expect(queue.count == 2) + #expect(!queue.isEmpty) let popResult = queue.pop(max: 3) - XCTAssert(popResult.elementsEqual([request1, request3])) - XCTAssert(queue.isEmpty) - XCTAssertEqual(queue.count, 0) + #expect(popResult.elementsEqual([request1, request3])) + #expect(queue.isEmpty) + #expect(queue.count == 0) } - XCTAssert(isKnownUniquelyReferenced(&request1)) - XCTAssert(isKnownUniquelyReferenced(&request2)) - XCTAssert(isKnownUniquelyReferenced(&request3)) + #expect(isKnownUniquelyReferenced(&request1)) + #expect(isKnownUniquelyReferenced(&request2)) + #expect(isKnownUniquelyReferenced(&request3)) } - func testRemoveAllAfterCancellation() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testRemoveAllAfterCancellation() { var queue = TestQueue() - XCTAssert(queue.isEmpty) + #expect(queue.isEmpty) var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) @@ -121,28 +125,28 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { queue.queue(request3) do { - XCTAssertEqual(queue.count, 3) + #expect(queue.count == 3) let returnedRequest2 = queue.remove(request2.id) - XCTAssert(returnedRequest2 === request2) - XCTAssertEqual(queue.count, 2) - XCTAssertFalse(queue.isEmpty) + #expect(returnedRequest2 === request2) + #expect(queue.count == 2) + #expect(!queue.isEmpty) } // still retained by the deque inside the queue - XCTAssertEqual(queue.requests.count, 2) - XCTAssertEqual(queue.queue.count, 3) + #expect(queue.requests.count == 2) + #expect(queue.queue.count == 3) do { - XCTAssertEqual(queue.count, 2) - XCTAssertFalse(queue.isEmpty) + #expect(queue.count == 2) + #expect(!queue.isEmpty) let removeAllResult = queue.removeAll() - XCTAssert(Set(removeAllResult) == [request1, request3]) - XCTAssert(queue.isEmpty) - XCTAssertEqual(queue.count, 0) + #expect(Set(removeAllResult) == [request1, request3]) + #expect(queue.isEmpty) + #expect(queue.count == 0) } - XCTAssert(isKnownUniquelyReferenced(&request1)) - XCTAssert(isKnownUniquelyReferenced(&request2)) - XCTAssert(isKnownUniquelyReferenced(&request3)) + #expect(isKnownUniquelyReferenced(&request1)) + #expect(isKnownUniquelyReferenced(&request2)) + #expect(isKnownUniquelyReferenced(&request3)) } } diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift index 08afdf8e..c748de28 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -1,6 +1,6 @@ @testable import _ConnectionPoolModule import _ConnectionPoolTestUtils -import XCTest +import Testing @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) typealias TestPoolStateMachine = PoolStateMachine< @@ -12,10 +12,10 @@ typealias TestPoolStateMachine = PoolStateMachine< MockTimerCancellationToken > -@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class PoolStateMachineTests: XCTestCase { +@Suite struct PoolStateMachineTests { - func testConnectionsAreCreatedAndParkedOnStartup() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testConnectionsAreCreatedAndParkedOnStartup() { var configuration = PoolConfiguration() configuration.minimumConnectionCount = 2 configuration.maximumConnectionSoftLimit = 4 @@ -33,25 +33,26 @@ final class PoolStateMachineTests: XCTestCase { do { let requests = stateMachine.refillConnections() - XCTAssertEqual(requests.count, 2) + #expect(requests.count == 2) let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) let connection1KeepAliveTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 0, usecase: .keepAlive), duration: .seconds(10)) let connection1KeepAliveTimerCancellationToken = MockTimerCancellationToken(connection1KeepAliveTimer) - XCTAssertEqual(createdAction1.request, .none) - XCTAssertEqual(createdAction1.connection, .scheduleTimers([connection1KeepAliveTimer])) + #expect(createdAction1.request == .none) + #expect(createdAction1.connection == .scheduleTimers([connection1KeepAliveTimer])) - XCTAssertEqual(stateMachine.timerScheduled(connection1KeepAliveTimer, cancelContinuation: connection1KeepAliveTimerCancellationToken), .none) + #expect(stateMachine.timerScheduled(connection1KeepAliveTimer, cancelContinuation: connection1KeepAliveTimerCancellationToken) == .none) let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) let connection2KeepAliveTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 1, usecase: .keepAlive), duration: .seconds(10)) let connection2KeepAliveTimerCancellationToken = MockTimerCancellationToken(connection2KeepAliveTimer) - XCTAssertEqual(createdAction2.request, .none) - XCTAssertEqual(createdAction2.connection, .scheduleTimers([connection2KeepAliveTimer])) - XCTAssertEqual(stateMachine.timerScheduled(connection2KeepAliveTimer, cancelContinuation: connection2KeepAliveTimerCancellationToken), .none) + #expect(createdAction2.request == .none) + #expect(createdAction2.connection == .scheduleTimers([connection2KeepAliveTimer])) + #expect(stateMachine.timerScheduled(connection2KeepAliveTimer, cancelContinuation: connection2KeepAliveTimerCancellationToken) == .none) } } - func testConnectionsNoKeepAliveRun() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testConnectionsNoKeepAliveRun() { var configuration = PoolConfiguration() configuration.minimumConnectionCount = 1 configuration.maximumConnectionSoftLimit = 4 @@ -69,51 +70,52 @@ final class PoolStateMachineTests: XCTestCase { // refill pool to at least one connection let requests = stateMachine.refillConnections() - XCTAssertEqual(requests.count, 1) + #expect(requests.count == 1) let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) - XCTAssertEqual(createdAction1.request, .none) - XCTAssertEqual(createdAction1.connection, .scheduleTimers([])) + #expect(createdAction1.request == .none) + #expect(createdAction1.connection == .scheduleTimers([])) // lease connection 1 let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) - XCTAssertEqual(leaseRequest1.connection, .cancelTimers([])) - XCTAssertEqual(leaseRequest1.request, .leaseConnection(.init(element: request1), connection1)) + #expect(leaseRequest1.connection == .cancelTimers([])) + #expect(leaseRequest1.request == .leaseConnection(.init(element: request1), connection1)) // release connection 1 - XCTAssertEqual(stateMachine.releaseConnection(connection1, streams: 1), .none()) + #expect(stateMachine.releaseConnection(connection1, streams: 1) == .none()) // lease connection 1 let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) - XCTAssertEqual(leaseRequest2.connection, .cancelTimers([])) - XCTAssertEqual(leaseRequest2.request, .leaseConnection(.init(element: request2), connection1)) + #expect(leaseRequest2.connection == .cancelTimers([])) + #expect(leaseRequest2.request == .leaseConnection(.init(element: request2), connection1)) // request connection while none is available let request3 = MockRequest(connectionType: MockConnection.self) let leaseRequest3 = stateMachine.leaseConnection(request3) - XCTAssertEqual(leaseRequest3.connection, .makeConnection(.init(connectionID: 1), [])) - XCTAssertEqual(leaseRequest3.request, .none) + #expect(leaseRequest3.connection == .makeConnection(.init(connectionID: 1), [])) + #expect(leaseRequest3.request == .none) // make connection 2 and lease immediately let connection2 = MockConnection(id: 1) let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) - XCTAssertEqual(createdAction2.request, .leaseConnection(.init(element: request3), connection2)) - XCTAssertEqual(createdAction2.connection, .none) + #expect(createdAction2.request == .leaseConnection(.init(element: request3), connection2)) + #expect(createdAction2.connection == .none) // release connection 2 let connection2IdleTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 1, usecase: .idleTimeout), duration: configuration.idleTimeoutDuration) let connection2IdleTimerCancellationToken = MockTimerCancellationToken(connection2IdleTimer) - XCTAssertEqual( - stateMachine.releaseConnection(connection2, streams: 1), + #expect( + stateMachine.releaseConnection(connection2, streams: 1) == .init(request: .none, connection: .scheduleTimers([connection2IdleTimer])) ) - XCTAssertEqual(stateMachine.timerScheduled(connection2IdleTimer, cancelContinuation: connection2IdleTimerCancellationToken), .none) - XCTAssertEqual(stateMachine.timerTriggered(connection2IdleTimer), .init(request: .none, connection: .closeConnection(connection2, [connection2IdleTimerCancellationToken]))) + #expect(stateMachine.timerScheduled(connection2IdleTimer, cancelContinuation: connection2IdleTimerCancellationToken) == .none) + #expect(stateMachine.timerTriggered(connection2IdleTimer) == .init(request: .none, connection: .closeConnection(connection2, [connection2IdleTimerCancellationToken]))) } - func testOnlyOverflowConnections() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testOnlyOverflowConnections() { var configuration = PoolConfiguration() configuration.minimumConnectionCount = 0 configuration.maximumConnectionSoftLimit = 0 @@ -129,49 +131,50 @@ final class PoolStateMachineTests: XCTestCase { // don't refill pool let requests = stateMachine.refillConnections() - XCTAssertEqual(requests.count, 0) + #expect(requests.count == 0) // request connection while none exists let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) - XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) - XCTAssertEqual(leaseRequest1.request, .none) + #expect(leaseRequest1.connection == .makeConnection(.init(connectionID: 0), [])) + #expect(leaseRequest1.request == .none) // make connection 1 and lease immediately let connection1 = MockConnection(id: 0) let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) - XCTAssertEqual(createdAction1.request, .leaseConnection(.init(element: request1), connection1)) - XCTAssertEqual(createdAction1.connection, .none) + #expect(createdAction1.request == .leaseConnection(.init(element: request1), connection1)) + #expect(createdAction1.connection == .none) // request connection while none is available let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) - XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) - XCTAssertEqual(leaseRequest2.request, .none) + #expect(leaseRequest2.connection == .makeConnection(.init(connectionID: 1), [])) + #expect(leaseRequest2.request == .none) // release connection 1 should be leased again immediately let releaseRequest1 = stateMachine.releaseConnection(connection1, streams: 1) - XCTAssertEqual(releaseRequest1.request, .leaseConnection(.init(element: request2), connection1)) - XCTAssertEqual(releaseRequest1.connection, .none) + #expect(releaseRequest1.request == .leaseConnection(.init(element: request2), connection1)) + #expect(releaseRequest1.connection == .none) // connection 2 comes up and should be closed right away let connection2 = MockConnection(id: 1) let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) - XCTAssertEqual(createdAction2.request, .none) - XCTAssertEqual(createdAction2.connection, .closeConnection(connection2, [])) - XCTAssertEqual(stateMachine.connectionClosed(connection2), .none()) + #expect(createdAction2.request == .none) + #expect(createdAction2.connection == .closeConnection(connection2, [])) + #expect(stateMachine.connectionClosed(connection2) == .none()) // release connection 1 should be closed as well let releaseRequest2 = stateMachine.releaseConnection(connection1, streams: 1) - XCTAssertEqual(releaseRequest2.request, .none) - XCTAssertEqual(releaseRequest2.connection, .closeConnection(connection1, [])) + #expect(releaseRequest2.request == .none) + #expect(releaseRequest2.connection == .closeConnection(connection1, [])) let shutdownAction = stateMachine.triggerForceShutdown() - XCTAssertEqual(shutdownAction.request, .failRequests(.init(), .poolShutdown)) - XCTAssertEqual(shutdownAction.connection, .shutdown(.init())) + #expect(shutdownAction.request == .failRequests(.init(), .poolShutdown)) + #expect(shutdownAction.connection == .shutdown(.init())) } - func testDemandConnectionIsMadePermanentIfPermanentIsClose() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testDemandConnectionIsMadePermanentIfPermanentIsClose() { var configuration = PoolConfiguration() configuration.minimumConnectionCount = 1 configuration.maximumConnectionSoftLimit = 2 @@ -189,44 +192,45 @@ final class PoolStateMachineTests: XCTestCase { // refill pool to at least one connection let requests = stateMachine.refillConnections() - XCTAssertEqual(requests.count, 1) + #expect(requests.count == 1) let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) - XCTAssertEqual(createdAction1.request, .none) - XCTAssertEqual(createdAction1.connection, .scheduleTimers([])) + #expect(createdAction1.request == .none) + #expect(createdAction1.connection == .scheduleTimers([])) // lease connection 1 let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) - XCTAssertEqual(leaseRequest1.connection, .cancelTimers([])) - XCTAssertEqual(leaseRequest1.request, .leaseConnection(.init(element: request1), connection1)) + #expect(leaseRequest1.connection == .cancelTimers([])) + #expect(leaseRequest1.request == .leaseConnection(.init(element: request1), connection1)) // request connection while none is available let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) - XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) - XCTAssertEqual(leaseRequest2.request, .none) + #expect(leaseRequest2.connection == .makeConnection(.init(connectionID: 1), [])) + #expect(leaseRequest2.request == .none) // make connection 2 and lease immediately let connection2 = MockConnection(id: 1) let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) - XCTAssertEqual(createdAction2.request, .leaseConnection(.init(element: request2), connection2)) - XCTAssertEqual(createdAction2.connection, .none) + #expect(createdAction2.request == .leaseConnection(.init(element: request2), connection2)) + #expect(createdAction2.connection == .none) // release connection 2 let connection2IdleTimer = TestPoolStateMachine.Timer(.init(timerID: 0, connectionID: 1, usecase: .idleTimeout), duration: configuration.idleTimeoutDuration) let connection2IdleTimerCancellationToken = MockTimerCancellationToken(connection2IdleTimer) - XCTAssertEqual( - stateMachine.releaseConnection(connection2, streams: 1), + #expect( + stateMachine.releaseConnection(connection2, streams: 1) == .init(request: .none, connection: .scheduleTimers([connection2IdleTimer])) ) - XCTAssertEqual(stateMachine.timerScheduled(connection2IdleTimer, cancelContinuation: connection2IdleTimerCancellationToken), .none) + #expect(stateMachine.timerScheduled(connection2IdleTimer, cancelContinuation: connection2IdleTimerCancellationToken) == .none) // connection 1 is dropped - XCTAssertEqual(stateMachine.connectionClosed(connection1), .init(request: .none, connection: .cancelTimers([connection2IdleTimerCancellationToken]))) + #expect(stateMachine.connectionClosed(connection1) == .init(request: .none, connection: .cancelTimers([connection2IdleTimerCancellationToken]))) } - func testReleaseLoosesRaceAgainstClosed() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testReleaseLoosesRaceAgainstClosed() { var configuration = PoolConfiguration() configuration.minimumConnectionCount = 0 configuration.maximumConnectionSoftLimit = 2 @@ -242,32 +246,33 @@ final class PoolStateMachineTests: XCTestCase { // don't refill pool let requests = stateMachine.refillConnections() - XCTAssertEqual(requests.count, 0) + #expect(requests.count == 0) // request connection while none exists let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) - XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) - XCTAssertEqual(leaseRequest1.request, .none) + #expect(leaseRequest1.connection == .makeConnection(.init(connectionID: 0), [])) + #expect(leaseRequest1.request == .none) // make connection 1 and lease immediately let connection1 = MockConnection(id: 0) let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) - XCTAssertEqual(createdAction1.request, .leaseConnection(.init(element: request1), connection1)) - XCTAssertEqual(createdAction1.connection, .none) + #expect(createdAction1.request == .leaseConnection(.init(element: request1), connection1)) + #expect(createdAction1.connection == .none) // connection got closed let closedAction = stateMachine.connectionClosed(connection1) - XCTAssertEqual(closedAction.connection, .none) - XCTAssertEqual(closedAction.request, .none) + #expect(closedAction.connection == .none) + #expect(closedAction.request == .none) // release connection 1 should be leased again immediately let releaseRequest1 = stateMachine.releaseConnection(connection1, streams: 1) - XCTAssertEqual(releaseRequest1.request, .none) - XCTAssertEqual(releaseRequest1.connection, .none) + #expect(releaseRequest1.request == .none) + #expect(releaseRequest1.connection == .none) } - func testKeepAliveOnClosingConnection() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testKeepAliveOnClosingConnection() { var configuration = PoolConfiguration() configuration.minimumConnectionCount = 0 configuration.maximumConnectionSoftLimit = 2 @@ -275,7 +280,6 @@ final class PoolStateMachineTests: XCTestCase { configuration.keepAliveDuration = .seconds(2) configuration.idleTimeoutDuration = .seconds(4) - var stateMachine = TestPoolStateMachine( configuration: configuration, generator: .init(), @@ -284,46 +288,46 @@ final class PoolStateMachineTests: XCTestCase { // don't refill pool let requests = stateMachine.refillConnections() - XCTAssertEqual(requests.count, 0) + #expect(requests.count == 0) // request connection while none exists let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) - XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) - XCTAssertEqual(leaseRequest1.request, .none) + #expect(leaseRequest1.connection == .makeConnection(.init(connectionID: 0), [])) + #expect(leaseRequest1.request == .none) // make connection 1 let connection1 = MockConnection(id: 0) let createdAction1 = stateMachine.connectionEstablished(connection1, maxStreams: 1) - XCTAssertEqual(createdAction1.request, .leaseConnection(.init(element: request1), connection1)) - XCTAssertEqual(createdAction1.connection, .none) + #expect(createdAction1.request == .leaseConnection(.init(element: request1), connection1)) + #expect(createdAction1.connection == .none) _ = stateMachine.releaseConnection(connection1, streams: 1) // trigger keep alive let keepAliveAction1 = stateMachine.connectionKeepAliveTimerTriggered(connection1.id) - XCTAssertEqual(keepAliveAction1.connection, .runKeepAlive(connection1, nil)) + #expect(keepAliveAction1.connection == .runKeepAlive(connection1, nil)) // fail keep alive and cause closed let keepAliveFailed1 = stateMachine.connectionKeepAliveFailed(connection1.id) - XCTAssertEqual(keepAliveFailed1.connection, .closeConnection(connection1, [])) + #expect(keepAliveFailed1.connection == .closeConnection(connection1, [])) connection1.closeIfClosing() // request connection while none exists anymore let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) - XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) - XCTAssertEqual(leaseRequest2.request, .none) + #expect(leaseRequest2.connection == .makeConnection(.init(connectionID: 1), [])) + #expect(leaseRequest2.request == .none) // make connection 2 let connection2 = MockConnection(id: 1) let createdAction2 = stateMachine.connectionEstablished(connection2, maxStreams: 1) - XCTAssertEqual(createdAction2.request, .leaseConnection(.init(element: request2), connection2)) - XCTAssertEqual(createdAction2.connection, .none) + #expect(createdAction2.request == .leaseConnection(.init(element: request2), connection2)) + #expect(createdAction2.connection == .none) _ = stateMachine.releaseConnection(connection2, streams: 1) // trigger keep alive while connection is still open let keepAliveAction2 = stateMachine.connectionKeepAliveTimerTriggered(connection2.id) - XCTAssertEqual(keepAliveAction2.connection, .runKeepAlive(connection2, nil)) + #expect(keepAliveAction2.connection == .runKeepAlive(connection2, nil)) // close connection in the middle of keep alive connection2.close() @@ -331,10 +335,11 @@ final class PoolStateMachineTests: XCTestCase { // fail keep alive and cause closed let keepAliveFailed2 = stateMachine.connectionKeepAliveFailed(connection2.id) - XCTAssertEqual(keepAliveFailed2.connection, .closeConnection(connection2, [])) + #expect(keepAliveFailed2.connection == .closeConnection(connection2, [])) } - func testConnectionIsEstablishedAfterFailedKeepAliveIfNotEnoughConnectionsLeft() { + @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) + @Test func testConnectionIsEstablishedAfterFailedKeepAliveIfNotEnoughConnectionsLeft() { var configuration = PoolConfiguration() configuration.minimumConnectionCount = 1 configuration.maximumConnectionSoftLimit = 2 @@ -351,35 +356,38 @@ final class PoolStateMachineTests: XCTestCase { // refill pool let requests = stateMachine.refillConnections() - XCTAssertEqual(requests.count, 1) + #expect(requests.count == 1) // one connection should exist let request = MockRequest(connectionType: MockConnection.self) let leaseRequest = stateMachine.leaseConnection(request) - XCTAssertEqual(leaseRequest.connection, .none) - XCTAssertEqual(leaseRequest.request, .none) + #expect(leaseRequest.connection == .none) + #expect(leaseRequest.request == .none) // make connection 1 let connection = MockConnection(id: 0) let createdAction = stateMachine.connectionEstablished(connection, maxStreams: 1) - XCTAssertEqual(createdAction.request, .leaseConnection(.init(element: request), connection)) - XCTAssertEqual(createdAction.connection, .none) + #expect(createdAction.request == .leaseConnection(.init(element: request), connection)) + #expect(createdAction.connection == .none) _ = stateMachine.releaseConnection(connection, streams: 1) // trigger keep alive let keepAliveAction = stateMachine.connectionKeepAliveTimerTriggered(connection.id) - XCTAssertEqual(keepAliveAction.connection, .runKeepAlive(connection, nil)) + #expect(keepAliveAction.connection == .runKeepAlive(connection, nil)) // fail keep alive, cause closed and make new connection let keepAliveFailed = stateMachine.connectionKeepAliveFailed(connection.id) - XCTAssertEqual(keepAliveFailed.connection, .closeConnection(connection, [])) + #expect(keepAliveFailed.connection == .closeConnection(connection, [])) let connectionClosed = stateMachine.connectionClosed(connection) - XCTAssertEqual(connectionClosed.connection, .makeConnection(.init(connectionID: 1), [])) + #expect(connectionClosed.connection == .makeConnection(.init(connectionID: 1), [])) connection.closeIfClosing() let establishAction = stateMachine.connectionEstablished(.init(id: 1), maxStreams: 1) - XCTAssertEqual(establishAction.request, .none) - guard case .scheduleTimers(let timers) = establishAction.connection else { return XCTFail("Unexpected connection action") } - XCTAssertEqual(timers, [.init(.init(timerID: 0, connectionID: 1, usecase: .keepAlive), duration: configuration.keepAliveDuration!)]) + #expect(establishAction.request == .none) + if case .scheduleTimers(let timers) = establishAction.connection { + #expect(timers == [.init(.init(timerID: 0, connectionID: 1, usecase: .keepAlive), duration: configuration.keepAliveDuration!)]) + } else { + Issue.record("Unexpected connection action") + } } } diff --git a/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift index 6be22005..9dfac549 100644 --- a/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift +++ b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift @@ -1,80 +1,84 @@ @testable import _ConnectionPoolModule -import XCTest +import Testing -final class TinyFastSequenceTests: XCTestCase { - func testCountIsEmptyAndIterator() async { +@Suite struct TinyFastSequenceTests { + @Test func testCountIsEmptyAndIterator() { var sequence = TinyFastSequence() - XCTAssertEqual(sequence.count, 0) - XCTAssertEqual(sequence.isEmpty, true) - XCTAssertEqual(sequence.first, nil) - XCTAssertEqual(Array(sequence), []) + #expect(sequence.count == 0) + #expect(sequence.isEmpty == true) + #expect(sequence.first == nil) + #expect(Array(sequence) == []) sequence.append(1) - XCTAssertEqual(sequence.count, 1) - XCTAssertEqual(sequence.isEmpty, false) - XCTAssertEqual(sequence.first, 1) - XCTAssertEqual(Array(sequence), [1]) + #expect(sequence.count == 1) + #expect(sequence.isEmpty == false) + #expect(sequence.first == 1) + #expect(Array(sequence) == [1]) sequence.append(2) - XCTAssertEqual(sequence.count, 2) - XCTAssertEqual(sequence.isEmpty, false) - XCTAssertEqual(sequence.first, 1) - XCTAssertEqual(Array(sequence), [1, 2]) + #expect(sequence.count == 2) + #expect(sequence.isEmpty == false) + #expect(sequence.first == 1) + #expect(Array(sequence) == [1, 2]) sequence.append(3) - XCTAssertEqual(sequence.count, 3) - XCTAssertEqual(sequence.isEmpty, false) - XCTAssertEqual(sequence.first, 1) - XCTAssertEqual(Array(sequence), [1, 2, 3]) + #expect(sequence.count == 3) + #expect(sequence.isEmpty == false) + #expect(sequence.first == 1) + #expect(Array(sequence) == [1, 2, 3]) } - func testReserveCapacityIsForwarded() { + @Test func testReserveCapacityIsForwarded() { var emptySequence = TinyFastSequence() emptySequence.reserveCapacity(8) emptySequence.append(1) emptySequence.append(2) emptySequence.append(3) guard case .n(let array) = emptySequence.base else { - return XCTFail("Expected sequence to be backed by an array") + Issue.record("Expected sequence to be backed by an array") + return } - XCTAssertGreaterThanOrEqual(array.capacity, 8) + #expect(array.capacity >= 8) var oneElemSequence = TinyFastSequence(element: 1) oneElemSequence.reserveCapacity(8) oneElemSequence.append(2) oneElemSequence.append(3) guard case .n(let array) = oneElemSequence.base else { - return XCTFail("Expected sequence to be backed by an array") + Issue.record("Expected sequence to be backed by an array") + return } - XCTAssertGreaterThanOrEqual(array.capacity, 8) + #expect(array.capacity >= 8) var twoElemSequence = TinyFastSequence([1, 2]) twoElemSequence.reserveCapacity(8) twoElemSequence.append(3) guard case .n(let array) = twoElemSequence.base else { - return XCTFail("Expected sequence to be backed by an array") + Issue.record("Expected sequence to be backed by an array") + return } - XCTAssertGreaterThanOrEqual(array.capacity, 8) + #expect(array.capacity >= 8) var threeElemSequence = TinyFastSequence([1, 2, 3]) threeElemSequence.reserveCapacity(8) guard case .n(let array) = threeElemSequence.base else { - return XCTFail("Expected sequence to be backed by an array") + Issue.record("Expected sequence to be backed by an array") + return } - XCTAssertGreaterThanOrEqual(array.capacity, 8) + #expect(array.capacity >= 8) } - func testNewSequenceSlowPath() { + @Test func testNewSequenceSlowPath() { let sequence = TinyFastSequence("AB".utf8) - XCTAssertEqual(Array(sequence), [UInt8(ascii: "A"), UInt8(ascii: "B")]) + #expect(Array(sequence) == [UInt8(ascii: "A"), UInt8(ascii: "B")]) } - func testSingleItem() { + @Test func testSingleItem() { let sequence = TinyFastSequence("A".utf8) - XCTAssertEqual(Array(sequence), [UInt8(ascii: "A")]) + #expect(Array(sequence) == [UInt8(ascii: "A")]) } - func testEmptyCollection() { + @Test func testEmptyCollection() { let sequence = TinyFastSequence("".utf8) - XCTAssertTrue(sequence.isEmpty) - XCTAssertEqual(sequence.count, 0) - XCTAssertEqual(Array(sequence), []) + #expect(sequence.isEmpty == true) + #expect(sequence.count == 0) + #expect(Array(sequence) == []) } } From 4ebc10cb4e8b199083f4a6251896126a226e52f8 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 18 Sep 2025 11:40:23 +0200 Subject: [PATCH 289/292] Add public `triggerForceShutdown` to `ConnectionPool` (#572) Add distinct method that we need in Valkey (valkey-io/valkey-swift#85) --- Sources/ConnectionPoolModule/ConnectionPool.swift | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index ee72337d..40d52a5a 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -283,12 +283,15 @@ public final class ConnectionPool< await self.run(in: &taskGroup) } } onCancel: { - let actions = self.stateBox.withLockedValue { state in - state.stateMachine.triggerForceShutdown() - } + self.triggerForceShutdown() + } + } - self.runStateMachineActions(actions) + public func triggerForceShutdown() { + let actions = self.stateBox.withLockedValue { state in + state.stateMachine.triggerForceShutdown() } + self.runStateMachineActions(actions) } // MARK: - Private Methods - From a7310fa9fde8aefc57f9ea841da1d1128cdb29e0 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 18 Sep 2025 17:49:04 +0200 Subject: [PATCH 290/292] PostgresNIOTests: Move first files to Swift Testing (#590) --- .../AuthenticationStateMachineTests.swift | 102 ++++----- .../ConnectionStateMachineTests.swift | 205 +++++++++--------- .../New/Data/Array+PSQLCodableTests.swift | 166 +++++++------- .../New/Data/Bool+PSQLCodableTests.swift | 70 +++--- .../New/Data/Bytes+PSQLCodableTests.swift | 29 +-- .../New/Messages/AuthenticationTests.swift | 16 +- .../New/Messages/BackendKeyDataTests.swift | 26 ++- .../New/Messages/BindTests.swift | 50 ++--- .../New/Messages/CancelTests.swift | 19 +- .../New/Messages/CloseTests.swift | 32 +-- .../New/Messages/CopyTests.swift | 70 +++--- .../New/Messages/DataRowTests.swift | 80 +++---- .../New/Messages/DescribeTests.swift | 35 ++- .../New/Messages/ExecuteTests.swift | 17 +- .../New/Messages/ParseTests.swift | 20 +- .../New/Messages/PasswordTests.swift | 15 +- .../Messages/SASLInitialResponseTests.swift | 36 +-- .../New/Messages/SASLResponseTests.swift | 26 +-- .../New/Messages/StartupTests.swift | 61 +++--- .../New/PostgresCellTests.swift | 69 ++++-- .../New/PostgresRowSequenceTests.swift | 116 +++++----- 21 files changed, 648 insertions(+), 612 deletions(-) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift index df881f90..99f7f5e9 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift @@ -1,82 +1,84 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class AuthenticationStateMachineTests: XCTestCase { - - func testAuthenticatePlaintext() { +@Suite struct AuthenticationStateMachineTests { + + @Test func testAuthenticatePlaintext() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) - XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) - - XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) - XCTAssertEqual(state.authenticationMessageReceived(.plaintext), .sendPasswordMessage(.cleartext, authContext)) - XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait) + #expect(state.connected(tls: .disable) == .provideAuthenticationContext) + + #expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext)) + #expect(state.authenticationMessageReceived(.plaintext) == .sendPasswordMessage(.cleartext, authContext)) + #expect(state.authenticationMessageReceived(.ok) == .wait) } - func testAuthenticateMD5() { + @Test func testAuthenticateMD5() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) - XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) + #expect(state.connected(tls: .disable) == .provideAuthenticationContext) let salt: UInt32 = 0x00_01_02_03 - XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) - XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) - XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait) + #expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext)) + #expect(state.authenticationMessageReceived(.md5(salt: salt)) == .sendPasswordMessage(.md5(salt: salt), authContext)) + #expect(state.authenticationMessageReceived(.ok) == .wait) } - func testAuthenticateMD5WithoutPassword() { + @Test func testAuthenticateMD5WithoutPassword() { let authContext = AuthContext(username: "test", password: nil, database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) - XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) + #expect(state.connected(tls: .disable) == .provideAuthenticationContext) let salt: UInt32 = 0x00_01_02_03 - XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) - XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), + #expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext)) + #expect(state.authenticationMessageReceived(.md5(salt: salt)) == .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .authMechanismRequiresPassword, closePromise: nil))) } - func testAuthenticateOkAfterStartUpWithoutAuthChallenge() { + @Test func testAuthenticateOkAfterStartUpWithoutAuthChallenge() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) - XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) - XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) - XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait) + #expect(state.connected(tls: .disable) == .provideAuthenticationContext) + #expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext)) + #expect(state.authenticationMessageReceived(.ok) == .wait) } - func testAuthenticateSCRAMSHA256WithAtypicalEncoding() { + @Test func testAuthenticateSCRAMSHA256WithAtypicalEncoding() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) - XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) - XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) - + #expect(state.connected(tls: .disable) == .provideAuthenticationContext) + #expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext)) + let saslResponse = state.authenticationMessageReceived(.sasl(names: ["SCRAM-SHA-256"])) guard case .sendSaslInitialResponse(name: let name, initialResponse: let responseData) = saslResponse else { - return XCTFail("\(saslResponse) is not .sendSaslInitialResponse") + Issue.record("\(saslResponse) is not .sendSaslInitialResponse") + return } let responseString = String(decoding: responseData, as: UTF8.self) - XCTAssertEqual(name, "SCRAM-SHA-256") - XCTAssert(responseString.starts(with: "n,,n=test,r=")) - + #expect(name == "SCRAM-SHA-256") + #expect(responseString.starts(with: "n,,n=test,r=")) + let saslContinueResponse = state.authenticationMessageReceived(.saslContinue(data: .init(bytes: "r=\(responseString.dropFirst(12))RUJSZHhkeUVFNzRLNERKMkxmU05ITU1NZWcxaQ==,s=ijgUVaWgCDLRJyF963BKNA==,i=4096".utf8 ))) guard case .sendSaslResponse(let responseData2) = saslContinueResponse else { - return XCTFail("\(saslContinueResponse) is not .sendSaslResponse") + Issue.record("\(saslContinueResponse) is not .sendSaslResponse") + return } let response2String = String(decoding: responseData2, as: UTF8.self) - XCTAssertEqual(response2String.prefix(76), "c=biws,r=\(responseString.dropFirst(12))RUJSZHhkeUVFNzRLNERKMkxmU05ITU1NZWcxaQ==,p=") + #expect(response2String.prefix(76) == "c=biws,r=\(responseString.dropFirst(12))RUJSZHhkeUVFNzRLNERKMkxmU05ITU1NZWcxaQ==,p=") } - func testAuthenticationFailure() { + @Test func testAuthenticationFailure() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) - XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) + #expect(state.connected(tls: .disable) == .provideAuthenticationContext) let salt: UInt32 = 0x00_01_02_03 - XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) - XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) + #expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext)) + #expect(state.authenticationMessageReceived(.md5(salt: salt)) == .sendPasswordMessage(.md5(salt: salt), authContext)) let fields: [PostgresBackendMessage.Field: String] = [ .message: "password authentication failed for user \"postgres\"", .severity: "FATAL", @@ -86,13 +88,13 @@ class AuthenticationStateMachineTests: XCTestCase { .line: "334", .file: "auth.c" ] - XCTAssertEqual(state.errorReceived(.init(fields: fields)), + #expect(state.errorReceived(.init(fields: fields)) == .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .server(.init(fields: fields)), closePromise: nil))) } // MARK: Test unsupported messages - func testUnsupportedAuthMechanism() { + @Test func testUnsupportedAuthMechanism() { let unsupported: [(PostgresBackendMessage.Authentication, PSQLError.UnsupportedAuthScheme)] = [ (.kerberosV5, .kerberosV5), (.scmCredential, .scmCredential), @@ -104,14 +106,14 @@ class AuthenticationStateMachineTests: XCTestCase { for (message, mechanism) in unsupported { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) - XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) - XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) - XCTAssertEqual(state.authenticationMessageReceived(message), + #expect(state.connected(tls: .disable) == .provideAuthenticationContext) + #expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext)) + #expect(state.authenticationMessageReceived(message) == .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unsupportedAuthMechanism(mechanism), closePromise: nil))) } } - func testUnexpectedMessagesAfterStartUp() { + @Test func testUnexpectedMessagesAfterStartUp() { var buffer = ByteBuffer() buffer.writeBytes([0, 1, 2, 3, 4, 5, 6, 7, 8]) let unexpected: [PostgresBackendMessage.Authentication] = [ @@ -123,14 +125,14 @@ class AuthenticationStateMachineTests: XCTestCase { for message in unexpected { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) - XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) - XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) - XCTAssertEqual(state.authenticationMessageReceived(message), + #expect(state.connected(tls: .disable) == .provideAuthenticationContext) + #expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext)) + #expect(state.authenticationMessageReceived(message) == .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unexpectedBackendMessage(.authentication(message)), closePromise: nil))) } } - func testUnexpectedMessagesAfterPasswordSent() { + @Test func testUnexpectedMessagesAfterPasswordSent() { let salt: UInt32 = 0x00_01_02_03 var buffer = ByteBuffer() buffer.writeBytes([0, 1, 2, 3, 4, 5, 6, 7, 8]) @@ -150,10 +152,10 @@ class AuthenticationStateMachineTests: XCTestCase { for message in unexpected { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) - XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) - XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) - XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) - XCTAssertEqual(state.authenticationMessageReceived(message), + #expect(state.connected(tls: .disable) == .provideAuthenticationContext) + #expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext)) + #expect(state.authenticationMessageReceived(.md5(salt: salt)) == .sendPasswordMessage(.md5(salt: salt), authContext)) + #expect(state.authenticationMessageReceived(message) == .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unexpectedBackendMessage(.authentication(message)), closePromise: nil))) } } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index f3d72a5e..445feb25 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -1,162 +1,159 @@ -import XCTest +import Testing @testable import PostgresNIO @testable import NIOCore import NIOPosix import NIOSSL -class ConnectionStateMachineTests: XCTestCase { - - func testStartup() { +@Suite struct ConnectionStateMachineTests { + + @Test func testStartup() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) - XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) - XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) - XCTAssertEqual(state.authenticationMessageReceived(.plaintext), .sendPasswordMessage(.cleartext, authContext)) - XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait) + #expect(state.connected(tls: .disable) == .provideAuthenticationContext) + #expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext)) + #expect(state.authenticationMessageReceived(.plaintext) == .sendPasswordMessage(.cleartext, authContext)) + #expect(state.authenticationMessageReceived(.ok) == .wait) } - func testSSLStartupSuccess() { + @Test func testSSLStartupSuccess() { let authContext = AuthContext(username: "test", password: "abc123", database: "test") var state = ConnectionStateMachine(requireBackendKeyData: true) - XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest) - XCTAssertEqual(state.sslSupportedReceived(unprocessedBytes: 0), .establishSSLConnection) - XCTAssertEqual(state.sslHandlerAdded(), .wait) - XCTAssertEqual(state.sslEstablished(), .provideAuthenticationContext) - XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) + #expect(state.connected(tls: .require) == .sendSSLRequest) + #expect(state.sslSupportedReceived(unprocessedBytes: 0) == .establishSSLConnection) + #expect(state.sslHandlerAdded() == .wait) + #expect(state.sslEstablished() == .provideAuthenticationContext) + #expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext)) let salt: UInt32 = 0x00_01_02_03 - XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) + #expect(state.authenticationMessageReceived(.md5(salt: salt)) == .sendPasswordMessage(.md5(salt: salt), authContext)) } - func testSSLStartupFailureTooManyBytesRemaining() { + @Test func testSSLStartupFailureTooManyBytesRemaining() { var state = ConnectionStateMachine(requireBackendKeyData: true) - XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest) + #expect(state.connected(tls: .require) == .sendSSLRequest) let failError = PSQLError.receivedUnencryptedDataAfterSSLRequest - XCTAssertEqual(state.sslSupportedReceived(unprocessedBytes: 1), .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: failError, closePromise: nil))) + #expect(state.sslSupportedReceived(unprocessedBytes: 1) == .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: failError, closePromise: nil))) } - func testSSLStartupFailHandler() { + @Test func testSSLStartupFailHandler() { struct SSLHandlerAddError: Error, Equatable {} var state = ConnectionStateMachine(requireBackendKeyData: true) - XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest) - XCTAssertEqual(state.sslSupportedReceived(unprocessedBytes: 0), .establishSSLConnection) + #expect(state.connected(tls: .require) == .sendSSLRequest) + #expect(state.sslSupportedReceived(unprocessedBytes: 0) == .establishSSLConnection) let failError = PSQLError.failedToAddSSLHandler(underlying: SSLHandlerAddError()) - XCTAssertEqual(state.errorHappened(failError), .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: failError, closePromise: nil))) + #expect(state.errorHappened(failError) == .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: failError, closePromise: nil))) } - func testTLSRequiredStartupSSLUnsupported() { + @Test func testTLSRequiredStartupSSLUnsupported() { var state = ConnectionStateMachine(requireBackendKeyData: true) - XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest) - XCTAssertEqual(state.sslUnsupportedReceived(), + #expect(state.connected(tls: .require) == .sendSSLRequest) + #expect(state.sslUnsupportedReceived() == .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: PSQLError.sslUnsupported, closePromise: nil))) } - func testTLSPreferredStartupSSLUnsupported() { + @Test func testTLSPreferredStartupSSLUnsupported() { var state = ConnectionStateMachine(requireBackendKeyData: true) - XCTAssertEqual(state.connected(tls: .prefer), .sendSSLRequest) - XCTAssertEqual(state.sslUnsupportedReceived(), .provideAuthenticationContext) + #expect(state.connected(tls: .prefer) == .sendSSLRequest) + #expect(state.sslUnsupportedReceived() == .provideAuthenticationContext) } - func testParameterStatusReceivedAndBackendKeyAfterAuthenticated() { + @Test func testParameterStatusReceivedAndBackendKeyAfterAuthenticated() { var state = ConnectionStateMachine(.authenticated(nil, [:])) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "DateStyle", value: "ISO, MDY")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "application_name", value: "")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_encoding", value: "UTF8")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "integer_datetimes", value: "on")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "client_encoding", value: "UTF8")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "TimeZone", value: "Etc/UTC")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "is_superuser", value: "on")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_version", value: "13.1 (Debian 13.1-1.pgdg100+1)")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "session_authorization", value: "postgres")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "IntervalStyle", value: "postgres")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "standard_conforming_strings", value: "on")), .wait) - - XCTAssertEqual(state.backendKeyDataReceived(.init(processID: 2730, secretKey: 882037977)), .wait) - XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + #expect(state.parameterStatusReceived(.init(parameter: "DateStyle", value: "ISO, MDY")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "application_name", value: "")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "server_encoding", value: "UTF8")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "integer_datetimes", value: "on")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "client_encoding", value: "UTF8")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "TimeZone", value: "Etc/UTC")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "is_superuser", value: "on")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "server_version", value: "13.1 (Debian 13.1-1.pgdg100+1)")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "session_authorization", value: "postgres")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "IntervalStyle", value: "postgres")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "standard_conforming_strings", value: "on")) == .wait) + + #expect(state.backendKeyDataReceived(.init(processID: 2730, secretKey: 882037977)) == .wait) + #expect(state.readyForQueryReceived(.idle) == .fireEventReadyForQuery) } - func testBackendKeyAndParameterStatusReceivedAfterAuthenticated() { + @Test func testBackendKeyAndParameterStatusReceivedAfterAuthenticated() { var state = ConnectionStateMachine(.authenticated(nil, [:])) - XCTAssertEqual(state.backendKeyDataReceived(.init(processID: 2730, secretKey: 882037977)), .wait) - - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "DateStyle", value: "ISO, MDY")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "application_name", value: "")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_encoding", value: "UTF8")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "integer_datetimes", value: "on")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "client_encoding", value: "UTF8")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "TimeZone", value: "Etc/UTC")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "is_superuser", value: "on")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_version", value: "13.1 (Debian 13.1-1.pgdg100+1)")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "session_authorization", value: "postgres")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "IntervalStyle", value: "postgres")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "standard_conforming_strings", value: "on")), .wait) - - XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + #expect(state.backendKeyDataReceived(.init(processID: 2730, secretKey: 882037977)) == .wait) + + #expect(state.parameterStatusReceived(.init(parameter: "DateStyle", value: "ISO, MDY")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "application_name", value: "")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "server_encoding", value: "UTF8")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "integer_datetimes", value: "on")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "client_encoding", value: "UTF8")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "TimeZone", value: "Etc/UTC")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "is_superuser", value: "on")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "server_version", value: "13.1 (Debian 13.1-1.pgdg100+1)")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "session_authorization", value: "postgres")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "IntervalStyle", value: "postgres")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "standard_conforming_strings", value: "on")) == .wait) + + #expect(state.readyForQueryReceived(.idle) == .fireEventReadyForQuery) } - func testReadyForQueryReceivedWithoutBackendKeyAfterAuthenticated() { + @Test func testReadyForQueryReceivedWithoutBackendKeyAfterAuthenticated() { var state = ConnectionStateMachine(.authenticated(nil, [:]), requireBackendKeyData: true) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "DateStyle", value: "ISO, MDY")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "application_name", value: "")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_encoding", value: "UTF8")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "integer_datetimes", value: "on")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "client_encoding", value: "UTF8")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "TimeZone", value: "Etc/UTC")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "is_superuser", value: "on")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_version", value: "13.1 (Debian 13.1-1.pgdg100+1)")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "session_authorization", value: "postgres")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "IntervalStyle", value: "postgres")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "standard_conforming_strings", value: "on")), .wait) - - XCTAssertEqual(state.readyForQueryReceived(.idle), + #expect(state.parameterStatusReceived(.init(parameter: "DateStyle", value: "ISO, MDY")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "application_name", value: "")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "server_encoding", value: "UTF8")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "integer_datetimes", value: "on")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "client_encoding", value: "UTF8")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "TimeZone", value: "Etc/UTC")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "is_superuser", value: "on")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "server_version", value: "13.1 (Debian 13.1-1.pgdg100+1)")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "session_authorization", value: "postgres")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "IntervalStyle", value: "postgres")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "standard_conforming_strings", value: "on")) == .wait) + + #expect(state.readyForQueryReceived(.idle) == .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: PSQLError.unexpectedBackendMessage(.readyForQuery(.idle)), closePromise: nil))) } - func testReadyForQueryReceivedWithoutUnneededBackendKeyAfterAuthenticated() { + @Test func testReadyForQueryReceivedWithoutUnneededBackendKeyAfterAuthenticated() { var state = ConnectionStateMachine(.authenticated(nil, [:]), requireBackendKeyData: false) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "DateStyle", value: "ISO, MDY")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "application_name", value: "")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_encoding", value: "UTF8")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "integer_datetimes", value: "on")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "client_encoding", value: "UTF8")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "TimeZone", value: "Etc/UTC")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "is_superuser", value: "on")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_version", value: "13.1 (Debian 13.1-1.pgdg100+1)")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "session_authorization", value: "postgres")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "IntervalStyle", value: "postgres")), .wait) - XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "standard_conforming_strings", value: "on")), .wait) - - XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + #expect(state.parameterStatusReceived(.init(parameter: "DateStyle", value: "ISO, MDY")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "application_name", value: "")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "server_encoding", value: "UTF8")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "integer_datetimes", value: "on")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "client_encoding", value: "UTF8")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "TimeZone", value: "Etc/UTC")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "is_superuser", value: "on")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "server_version", value: "13.1 (Debian 13.1-1.pgdg100+1)")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "session_authorization", value: "postgres")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "IntervalStyle", value: "postgres")) == .wait) + #expect(state.parameterStatusReceived(.init(parameter: "standard_conforming_strings", value: "on")) == .wait) + + #expect(state.readyForQueryReceived(.idle) == .fireEventReadyForQuery) } - func testErrorIsIgnoredWhenClosingConnection() { + @Test func testErrorIsIgnoredWhenClosingConnection() { // test ignore unclean shutdown when closing connection var stateIgnoreChannelError = ConnectionStateMachine(.closing(nil)) - XCTAssertEqual(stateIgnoreChannelError.errorHappened(.connectionError(underlying: NIOSSLError.uncleanShutdown)), .wait) - XCTAssertEqual(stateIgnoreChannelError.closed(), .fireChannelInactive) - + #expect(stateIgnoreChannelError.errorHappened(.connectionError(underlying: NIOSSLError.uncleanShutdown)) == .wait) + #expect(stateIgnoreChannelError.closed() == .fireChannelInactive) + // test ignore any other error when closing connection var stateIgnoreErrorMessage = ConnectionStateMachine(.closing(nil)) - XCTAssertEqual(stateIgnoreErrorMessage.errorReceived(.init(fields: [:])), .wait) - XCTAssertEqual(stateIgnoreErrorMessage.closed(), .fireChannelInactive) + #expect(stateIgnoreErrorMessage.errorReceived(.init(fields: [:])) == .wait) + #expect(stateIgnoreErrorMessage.closed() == .fireChannelInactive) } - func testFailQueuedQueriesOnAuthenticationFailure() throws { - let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } - + @Test func testFailQueuedQueriesOnAuthenticationFailure() throws { let authContext = AuthContext(username: "test", password: "abc123", database: "test") let salt: UInt32 = 0x00_01_02_03 - let queryPromise = eventLoopGroup.next().makePromise(of: PSQLRowStream.self) + let queryPromise = NIOSingletons.posixEventLoopGroup.next().makePromise(of: PSQLRowStream.self) var state = ConnectionStateMachine(requireBackendKeyData: true) let extendedQueryContext = ExtendedQueryContext( @@ -164,10 +161,10 @@ class ConnectionStateMachineTests: XCTestCase { logger: .psqlTest, promise: queryPromise) - XCTAssertEqual(state.enqueue(task: .extendedQuery(extendedQueryContext)), .wait) - XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) - XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) - XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) + #expect(state.enqueue(task: .extendedQuery(extendedQueryContext)) == .wait) + #expect(state.connected(tls: .disable) == .provideAuthenticationContext) + #expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext)) + #expect(state.authenticationMessageReceived(.md5(salt: salt)) == .sendPasswordMessage(.md5(salt: salt), authContext)) let fields: [PostgresBackendMessage.Field: String] = [ .message: "password authentication failed for user \"postgres\"", .severity: "FATAL", @@ -177,10 +174,10 @@ class ConnectionStateMachineTests: XCTestCase { .line: "334", .file: "auth.c" ] - XCTAssertEqual(state.errorReceived(.init(fields: fields)), + #expect(state.errorReceived(.init(fields: fields)) == .closeConnectionAndCleanup(.init(action: .close, tasks: [.extendedQuery(extendedQueryContext)], error: .server(.init(fields: fields)), closePromise: nil))) - XCTAssertNil(queryPromise.futureResult._value) + #expect(queryPromise.futureResult._value == nil) // make sure we don't crash queryPromise.fail(PSQLError.server(.init(fields: fields))) diff --git a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift index bfffef52..1602cee0 100644 --- a/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Array+PSQLCodableTests.swift @@ -1,137 +1,141 @@ -import XCTest +import Foundation +import Testing import NIOCore @testable import PostgresNIO -class Array_PSQLCodableTests: XCTestCase { +@Suite struct Array_PSQLCodableTests { - func testArrayTypes() { + @Test func testArrayTypes() { + #expect(Bool.psqlArrayType == .boolArray) + #expect(Bool.psqlType == .bool) + #expect([Bool].psqlType == .boolArray) - XCTAssertEqual(Bool.psqlArrayType, .boolArray) - XCTAssertEqual(Bool.psqlType, .bool) - XCTAssertEqual([Bool].psqlType, .boolArray) + #expect(ByteBuffer.psqlArrayType == .byteaArray) + #expect(ByteBuffer.psqlType == .bytea) + #expect([ByteBuffer].psqlType == .byteaArray) - XCTAssertEqual(ByteBuffer.psqlArrayType, .byteaArray) - XCTAssertEqual(ByteBuffer.psqlType, .bytea) - XCTAssertEqual([ByteBuffer].psqlType, .byteaArray) + #expect(UInt8.psqlArrayType == .charArray) + #expect(UInt8.psqlType == .char) + #expect([UInt8].psqlType == .charArray) - XCTAssertEqual(UInt8.psqlArrayType, .charArray) - XCTAssertEqual(UInt8.psqlType, .char) - XCTAssertEqual([UInt8].psqlType, .charArray) + #expect(Int16.psqlArrayType == .int2Array) + #expect(Int16.psqlType == .int2) + #expect([Int16].psqlType == .int2Array) - XCTAssertEqual(Int16.psqlArrayType, .int2Array) - XCTAssertEqual(Int16.psqlType, .int2) - XCTAssertEqual([Int16].psqlType, .int2Array) + #expect(Int32.psqlArrayType == .int4Array) + #expect(Int32.psqlType == .int4) + #expect([Int32].psqlType == .int4Array) - XCTAssertEqual(Int32.psqlArrayType, .int4Array) - XCTAssertEqual(Int32.psqlType, .int4) - XCTAssertEqual([Int32].psqlType, .int4Array) - - XCTAssertEqual(Int64.psqlArrayType, .int8Array) - XCTAssertEqual(Int64.psqlType, .int8) - XCTAssertEqual([Int64].psqlType, .int8Array) + #expect(Int64.psqlArrayType == .int8Array) + #expect(Int64.psqlType == .int8) + #expect([Int64].psqlType == .int8Array) #if (arch(i386) || arch(arm)) - XCTAssertEqual(Int.psqlArrayType, .int4Array) - XCTAssertEqual(Int.psqlType, .int4) - XCTAssertEqual([Int].psqlType, .int4Array) + #expect(Int.psqlArrayType == .int4Array) + #expect(Int.psqlType == .int4) + #expect([Int].psqlType == .int4Array) #else - XCTAssertEqual(Int.psqlArrayType, .int8Array) - XCTAssertEqual(Int.psqlType, .int8) - XCTAssertEqual([Int].psqlType, .int8Array) + #expect(Int.psqlArrayType == .int8Array) + #expect(Int.psqlType == .int8) + #expect([Int].psqlType == .int8Array) #endif - XCTAssertEqual(Float.psqlArrayType, .float4Array) - XCTAssertEqual(Float.psqlType, .float4) - XCTAssertEqual([Float].psqlType, .float4Array) + #expect(Float.psqlArrayType == .float4Array) + #expect(Float.psqlType == .float4) + #expect([Float].psqlType == .float4Array) - XCTAssertEqual(Double.psqlArrayType, .float8Array) - XCTAssertEqual(Double.psqlType, .float8) - XCTAssertEqual([Double].psqlType, .float8Array) + #expect(Double.psqlArrayType == .float8Array) + #expect(Double.psqlType == .float8) + #expect([Double].psqlType == .float8Array) - XCTAssertEqual(String.psqlArrayType, .textArray) - XCTAssertEqual(String.psqlType, .text) - XCTAssertEqual([String].psqlType, .textArray) + #expect(String.psqlArrayType == .textArray) + #expect(String.psqlType == .text) + #expect([String].psqlType == .textArray) - XCTAssertEqual(UUID.psqlArrayType, .uuidArray) - XCTAssertEqual(UUID.psqlType, .uuid) - XCTAssertEqual([UUID].psqlType, .uuidArray) + #expect(UUID.psqlArrayType == .uuidArray) + #expect(UUID.psqlType == .uuid) + #expect([UUID].psqlType == .uuidArray) - XCTAssertEqual(Date.psqlArrayType, .timestamptzArray) - XCTAssertEqual(Date.psqlType, .timestamptz) - XCTAssertEqual([Date].psqlType, .timestamptzArray) + #expect(Date.psqlArrayType == .timestamptzArray) + #expect(Date.psqlType == .timestamptz) + #expect([Date].psqlType == .timestamptzArray) - XCTAssertEqual(Range.psqlArrayType, .int4RangeArray) - XCTAssertEqual(Range.psqlType, .int4Range) - XCTAssertEqual([Range].psqlType, .int4RangeArray) + #expect(Range.psqlArrayType == .int4RangeArray) + #expect(Range.psqlType == .int4Range) + #expect([Range].psqlType == .int4RangeArray) - XCTAssertEqual(ClosedRange.psqlArrayType, .int4RangeArray) - XCTAssertEqual(ClosedRange.psqlType, .int4Range) - XCTAssertEqual([ClosedRange].psqlType, .int4RangeArray) + #expect(ClosedRange.psqlArrayType == .int4RangeArray) + #expect(ClosedRange.psqlType == .int4Range) + #expect([ClosedRange].psqlType == .int4RangeArray) - XCTAssertEqual(Range.psqlArrayType, .int8RangeArray) - XCTAssertEqual(Range.psqlType, .int8Range) - XCTAssertEqual([Range].psqlType, .int8RangeArray) + #expect(Range.psqlArrayType == .int8RangeArray) + #expect(Range.psqlType == .int8Range) + #expect([Range].psqlType == .int8RangeArray) - XCTAssertEqual(ClosedRange.psqlArrayType, .int8RangeArray) - XCTAssertEqual(ClosedRange.psqlType, .int8Range) - XCTAssertEqual([ClosedRange].psqlType, .int8RangeArray) + #expect(ClosedRange.psqlArrayType == .int8RangeArray) + #expect(ClosedRange.psqlType == .int8Range) + #expect([ClosedRange].psqlType == .int8RangeArray) } - func testStringArrayRoundTrip() { + @Test func testStringArrayRoundTrip() { let values = ["foo", "bar", "hello", "world"] var buffer = ByteBuffer() values.encode(into: &buffer, context: .default) var result: [String]? - XCTAssertNoThrow(result = try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) - XCTAssertEqual(values, result) + #expect(throws: Never.self) { + result = try [String](from: &buffer, type: .textArray, format: .binary, context: .default) + } + #expect(values == result) } - func testEmptyStringArrayRoundTrip() { + @Test func testEmptyStringArrayRoundTrip() { let values: [String] = [] var buffer = ByteBuffer() values.encode(into: &buffer, context: .default) var result: [String]? - XCTAssertNoThrow(result = try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) - XCTAssertEqual(values, result) + #expect(throws: Never.self) { + result = try [String](from: &buffer, type: .textArray, format: .binary, context: .default) + } + #expect(values == result) } - func testDecodeFailureIsNotEmptyOutOfScope() { + @Test func testDecodeFailureIsNotEmptyOutOfScope() { var buffer = ByteBuffer() buffer.writeInteger(Int32(2)) // invalid value buffer.writeInteger(Int32(0)) buffer.writeInteger(String.psqlType.rawValue) - XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try [String](from: &buffer, type: .textArray, format: .binary, context: .default) } } - func testDecodeFailureSecondValueIsUnexpected() { + @Test func testDecodeFailureSecondValueIsUnexpected() { var buffer = ByteBuffer() buffer.writeInteger(Int32(0)) // is empty buffer.writeInteger(Int32(1)) // invalid value, must always be 0 buffer.writeInteger(String.psqlType.rawValue) - XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try [String](from: &buffer, type: .textArray, format: .binary, context: .default) } } - func testDecodeFailureTriesDecodeInt8() { + @Test func testDecodeFailureTriesDecodeInt8() { let value: Int64 = 1 << 32 var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) - XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try [String](from: &buffer, type: .textArray, format: .binary, context: .default) } } - func testDecodeFailureInvalidNumberOfArrayElements() { + @Test func testDecodeFailureInvalidNumberOfArrayElements() { var buffer = ByteBuffer() buffer.writeInteger(Int32(1)) // invalid value buffer.writeInteger(Int32(0)) @@ -139,12 +143,12 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int32(-123)) // expected element count buffer.writeInteger(Int32(1)) // dimensions... must be one - XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try [String](from: &buffer, type: .textArray, format: .binary, context: .default) } } - func testDecodeFailureInvalidNumberOfDimensions() { + @Test func testDecodeFailureInvalidNumberOfDimensions() { var buffer = ByteBuffer() buffer.writeInteger(Int32(1)) // invalid value buffer.writeInteger(Int32(0)) @@ -152,12 +156,12 @@ class Array_PSQLCodableTests: XCTestCase { buffer.writeInteger(Int32(1)) // expected element count buffer.writeInteger(Int32(2)) // dimensions... must be one - XCTAssertThrowsError(try [String](from: &buffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try [String](from: &buffer, type: .textArray, format: .binary, context: .default) } } - func testDecodeUnexpectedEnd() { + @Test func testDecodeUnexpectedEnd() { var unexpectedEndInElementLengthBuffer = ByteBuffer() unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // invalid value unexpectedEndInElementLengthBuffer.writeInteger(Int32(0)) @@ -166,8 +170,8 @@ class Array_PSQLCodableTests: XCTestCase { unexpectedEndInElementLengthBuffer.writeInteger(Int32(1)) // dimensions unexpectedEndInElementLengthBuffer.writeInteger(Int16(1)) // length of element, must be Int32 - XCTAssertThrowsError(try [String](from: &unexpectedEndInElementLengthBuffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try [String](from: &unexpectedEndInElementLengthBuffer, type: .textArray, format: .binary, context: .default) } var unexpectedEndInElementBuffer = ByteBuffer() @@ -179,8 +183,8 @@ class Array_PSQLCodableTests: XCTestCase { unexpectedEndInElementBuffer.writeInteger(Int32(12)) // length of element, must be Int32 unexpectedEndInElementBuffer.writeString("Hello World") // only 11 bytes, 12 needed! - XCTAssertThrowsError(try [String](from: &unexpectedEndInElementBuffer, type: .textArray, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try [String](from: &unexpectedEndInElementBuffer, type: .textArray, format: .binary, context: .default) } } } diff --git a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift index e6e43f0b..d23eff08 100644 --- a/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Bool+PSQLCodableTests.swift @@ -1,89 +1,97 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class Bool_PSQLCodableTests: XCTestCase { +@Suite struct Bool_PSQLCodableTests { // MARK: - Binary - func testBinaryTrueRoundTrip() { + @Test func testBinaryTrueRoundTrip() { let value = true var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) - XCTAssertEqual(Bool.psqlType, .bool) - XCTAssertEqual(Bool.psqlFormat, .binary) - XCTAssertEqual(buffer.readableBytes, 1) - XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 1) + #expect(Bool.psqlType == .bool) + #expect(Bool.psqlFormat == .binary) + #expect(buffer.readableBytes == 1) + #expect(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self) == 1) var result: Bool? - XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) - XCTAssertEqual(value, result) + #expect(throws: Never.self) { + result = try Bool(from: &buffer, type: .bool, format: .binary, context: .default) + } + #expect(value == result) } - func testBinaryFalseRoundTrip() { + @Test func testBinaryFalseRoundTrip() { let value = false var buffer = ByteBuffer() value.encode(into: &buffer, context: .default) - XCTAssertEqual(Bool.psqlType, .bool) - XCTAssertEqual(Bool.psqlFormat, .binary) - XCTAssertEqual(buffer.readableBytes, 1) - XCTAssertEqual(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self), 0) + #expect(Bool.psqlType == .bool) + #expect(Bool.psqlFormat == .binary) + #expect(buffer.readableBytes == 1) + #expect(buffer.getInteger(at: buffer.readerIndex, as: UInt8.self) == 0) var result: Bool? - XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) - XCTAssertEqual(value, result) + #expect(throws: Never.self) { + result = try Bool(from: &buffer, type: .bool, format: .binary, context: .default) + } + #expect(value == result) } - func testBinaryDecodeBoolInvalidLength() { + @Test func testBinaryDecodeBoolInvalidLength() { var buffer = ByteBuffer() buffer.writeInteger(Int64(1)) - XCTAssertThrowsError(try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try Bool(from: &buffer, type: .bool, format: .binary, context: .default) } } - func testBinaryDecodeBoolInvalidValue() { + @Test func testBinaryDecodeBoolInvalidValue() { var buffer = ByteBuffer() buffer.writeInteger(UInt8(13)) - XCTAssertThrowsError(try Bool(from: &buffer, type: .bool, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try Bool(from: &buffer, type: .bool, format: .binary, context: .default) } } // MARK: - Text - func testTextTrueDecode() { + @Test func testTextTrueDecode() { let value = true var buffer = ByteBuffer() buffer.writeInteger(UInt8(ascii: "t")) var result: Bool? - XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .text, context: .default)) - XCTAssertEqual(value, result) + #expect(throws: Never.self) { + result = try Bool(from: &buffer, type: .bool, format: .text, context: .default) + } + #expect(value == result) } - func testTextFalseDecode() { + @Test func testTextFalseDecode() { let value = false var buffer = ByteBuffer() buffer.writeInteger(UInt8(ascii: "f")) var result: Bool? - XCTAssertNoThrow(result = try Bool(from: &buffer, type: .bool, format: .text, context: .default)) - XCTAssertEqual(value, result) + #expect(throws: Never.self) { + result = try Bool(from: &buffer, type: .bool, format: .text, context: .default) + } + #expect(value == result) } - func testTextDecodeBoolInvalidValue() { + @Test func testTextDecodeBoolInvalidValue() { var buffer = ByteBuffer() buffer.writeInteger(UInt8(13)) - XCTAssertThrowsError(try Bool(from: &buffer, type: .bool, format: .text, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) + #expect(throws: PostgresDecodingError.Code.failure) { + try Bool(from: &buffer, type: .bool, format: .text, context: .default) } } } diff --git a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift index 9230aee7..77051775 100644 --- a/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Bytes+PSQLCodableTests.swift @@ -1,34 +1,35 @@ -import XCTest +import struct Foundation.Data +import Testing import NIOCore @testable import PostgresNIO -class Bytes_PSQLCodableTests: XCTestCase { - - func testDataRoundTrip() { +@Suite struct Bytes_PSQLCodableTests { + + @Test func testDataRoundTrip() { let data = Data((0...UInt8.max)) var buffer = ByteBuffer() data.encode(into: &buffer, context: .default) - XCTAssertEqual(ByteBuffer.psqlType, .bytea) - + #expect(ByteBuffer.psqlType == .bytea) + var result: Data? result = Data(from: &buffer, type: .bytea, format: .binary, context: .default) - XCTAssertEqual(data, result) + #expect(data == result) } - func testByteBufferRoundTrip() { + @Test func testByteBufferRoundTrip() { let bytes = ByteBuffer(bytes: (0...UInt8.max)) var buffer = ByteBuffer() bytes.encode(into: &buffer, context: .default) - XCTAssertEqual(ByteBuffer.psqlType, .bytea) - + #expect(ByteBuffer.psqlType == .bytea) + var result: ByteBuffer? result = ByteBuffer(from: &buffer, type: .bytea, format: .binary, context: .default) - XCTAssertEqual(bytes, result) + #expect(bytes == result) } - func testEncodeSequenceWhereElementUInt8() { + @Test func testEncodeSequenceWhereElementUInt8() { struct ByteSequence: Sequence, PostgresEncodable { typealias Element = UInt8 typealias Iterator = Array.Iterator @@ -47,7 +48,7 @@ class Bytes_PSQLCodableTests: XCTestCase { let sequence = ByteSequence() var buffer = ByteBuffer() sequence.encode(into: &buffer, context: .default) - XCTAssertEqual(ByteSequence.psqlType, .bytea) - XCTAssertEqual(buffer.readableBytes, 256) + #expect(ByteSequence.psqlType == .bytea) + #expect(buffer.readableBytes == 256) } } diff --git a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift index 06e39aae..3b857157 100644 --- a/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift @@ -1,10 +1,10 @@ -import XCTest +import Testing import NIOCore import NIOTestUtils @testable import PostgresNIO -class AuthenticationTests: XCTestCase { - +@Suite struct AuthenticationTests { + func testDecodeAuthentication() { var expected = [PostgresBackendMessage]() var buffer = ByteBuffer() @@ -39,9 +39,11 @@ class AuthenticationTests: XCTestCase { encoder.encode(data: .authentication(.sspi), out: &buffer) expected.append(.authentication(.sspi)) - XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( - inputOutputPairs: [(buffer, expected)], - decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) } - )) + #expect(throws: Never.self) { + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, expected)], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) } + ) + } } } diff --git a/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift b/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift index d41607e3..204b544d 100644 --- a/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BackendKeyDataTests.swift @@ -1,10 +1,10 @@ -import XCTest +import Testing import NIOCore import NIOTestUtils @testable import PostgresNIO -class BackendKeyDataTests: XCTestCase { - func testDecode() { +@Suite struct BackendKeyDataTests { + @Test func testDecode() { let buffer = ByteBuffer.backendMessage(id: .backendKeyData) { buffer in buffer.writeInteger(Int32(1234)) buffer.writeInteger(Int32(4567)) @@ -14,12 +14,15 @@ class BackendKeyDataTests: XCTestCase { (buffer, [PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 4567))]), ] - XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( - inputOutputPairs: expectedInOuts, - decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) + #expect(throws: Never.self) { + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: expectedInOuts, + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) } + ) + } } - func testDecodeInvalidLength() { + @Test func testDecodeInvalidLength() { var buffer = ByteBuffer() buffer.psqlWriteBackendMessageID(.backendKeyData) buffer.writeInteger(Int32(11)) @@ -30,10 +33,11 @@ class BackendKeyDataTests: XCTestCase { (buffer, [PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 4567))]), ] - XCTAssertThrowsError(try ByteToMessageDecoderVerifier.verifyDecoder( - inputOutputPairs: expected, - decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) { - XCTAssert($0 is PostgresMessageDecodingError) + #expect(throws: PostgresMessageDecodingError.self) { + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: expected, + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) } + ) } } } diff --git a/Tests/PostgresNIOTests/New/Messages/BindTests.swift b/Tests/PostgresNIOTests/New/Messages/BindTests.swift index d5ec5b30..24925fdf 100644 --- a/Tests/PostgresNIOTests/New/Messages/BindTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/BindTests.swift @@ -1,10 +1,10 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class BindTests: XCTestCase { - - func testEncodeBind() { +@Suite struct BindTests { + + @Test func testEncodeBind() { var bindings = PostgresBindings() bindings.append("Hello", context: .default) bindings.append("World", context: .default) @@ -14,34 +14,34 @@ class BindTests: XCTestCase { encoder.bind(portalName: "", preparedStatementName: "", bind: bindings) var byteBuffer = encoder.flushBuffer() - XCTAssertEqual(byteBuffer.readableBytes, 37) - XCTAssertEqual(PostgresFrontendMessage.ID.bind.rawValue, byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 36) - XCTAssertEqual("", byteBuffer.readNullTerminatedString()) - XCTAssertEqual("", byteBuffer.readNullTerminatedString()) + #expect(byteBuffer.readableBytes == 37) + #expect(PostgresFrontendMessage.ID.bind.rawValue == byteBuffer.readInteger(as: UInt8.self)) + #expect(byteBuffer.readInteger(as: Int32.self) == 36) + #expect("" == byteBuffer.readNullTerminatedString()) + #expect("" == byteBuffer.readNullTerminatedString()) // the number of parameters - XCTAssertEqual(2, byteBuffer.readInteger(as: Int16.self)) + #expect(2 == byteBuffer.readInteger(as: Int16.self)) // all (two) parameters have the same format (binary) - XCTAssertEqual(1, byteBuffer.readInteger(as: Int16.self)) - XCTAssertEqual(1, byteBuffer.readInteger(as: Int16.self)) - + #expect(1 == byteBuffer.readInteger(as: Int16.self)) + #expect(1 == byteBuffer.readInteger(as: Int16.self)) + // read number of parameters - XCTAssertEqual(2, byteBuffer.readInteger(as: Int16.self)) - + #expect(2 == byteBuffer.readInteger(as: Int16.self)) + // hello length - XCTAssertEqual(5, byteBuffer.readInteger(as: Int32.self)) - XCTAssertEqual("Hello", byteBuffer.readString(length: 5)) - + #expect(5 == byteBuffer.readInteger(as: Int32.self)) + #expect("Hello" == byteBuffer.readString(length: 5)) + // world length - XCTAssertEqual(5, byteBuffer.readInteger(as: Int32.self)) - XCTAssertEqual("World", byteBuffer.readString(length: 5)) - + #expect(5 == byteBuffer.readInteger(as: Int32.self)) + #expect("World" == byteBuffer.readString(length: 5)) + // all response values have the same format: therefore one format byte is next - XCTAssertEqual(1, byteBuffer.readInteger(as: Int16.self)) + #expect(1 == byteBuffer.readInteger(as: Int16.self)) // all response values have the same format (binary) - XCTAssertEqual(1, byteBuffer.readInteger(as: Int16.self)) - + #expect(1 == byteBuffer.readInteger(as: Int16.self)) + // nothing left to read - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == 0) } } diff --git a/Tests/PostgresNIOTests/New/Messages/CancelTests.swift b/Tests/PostgresNIOTests/New/Messages/CancelTests.swift index 5548aae3..c2da01d3 100644 --- a/Tests/PostgresNIOTests/New/Messages/CancelTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CancelTests.swift @@ -1,21 +1,20 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class CancelTests: XCTestCase { - - func testEncodeCancel() { +@Suite struct CancelTests { + @Test func testEncodeCancel() { let processID: Int32 = 1234 let secretKey: Int32 = 4567 var encoder = PostgresFrontendMessageEncoder(buffer: .init()) encoder.cancel(processID: processID, secretKey: secretKey) var byteBuffer = encoder.flushBuffer() - XCTAssertEqual(byteBuffer.readableBytes, 16) - XCTAssertEqual(16, byteBuffer.readInteger(as: Int32.self)) // payload length - XCTAssertEqual(80877102, byteBuffer.readInteger(as: Int32.self)) // cancel request code - XCTAssertEqual(processID, byteBuffer.readInteger(as: Int32.self)) - XCTAssertEqual(secretKey, byteBuffer.readInteger(as: Int32.self)) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == 16) + #expect(16 == byteBuffer.readInteger(as: Int32.self)) // payload length + #expect(80877102 == byteBuffer.readInteger(as: Int32.self)) // cancel request code + #expect(processID == byteBuffer.readInteger(as: Int32.self)) + #expect(secretKey == byteBuffer.readInteger(as: Int32.self)) + #expect(byteBuffer.readableBytes == 0) } } diff --git a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift index a8e1cfeb..9d6f1b37 100644 --- a/Tests/PostgresNIOTests/New/Messages/CloseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CloseTests.swift @@ -1,31 +1,31 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class CloseTests: XCTestCase { - func testEncodeClosePortal() { +@Suite struct CloseTests { + @Test func testEncodeClosePortal() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) encoder.closePortal("Hello") var byteBuffer = encoder.flushBuffer() - XCTAssertEqual(byteBuffer.readableBytes, 12) - XCTAssertEqual(PostgresFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual(11, byteBuffer.readInteger(as: Int32.self)) - XCTAssertEqual(UInt8(ascii: "P"), byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual("Hello", byteBuffer.readNullTerminatedString()) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == 12) + #expect(PostgresFrontendMessage.ID.close.rawValue == byteBuffer.readInteger(as: UInt8.self)) + #expect(11 == byteBuffer.readInteger(as: Int32.self)) + #expect(UInt8(ascii: "P") == byteBuffer.readInteger(as: UInt8.self)) + #expect("Hello" == byteBuffer.readNullTerminatedString()) + #expect(byteBuffer.readableBytes == 0) } - func testEncodeCloseUnnamedStatement() { + @Test func testEncodeCloseUnnamedStatement() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) encoder.closePreparedStatement("") var byteBuffer = encoder.flushBuffer() - XCTAssertEqual(byteBuffer.readableBytes, 7) - XCTAssertEqual(PostgresFrontendMessage.ID.close.rawValue, byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual(6, byteBuffer.readInteger(as: Int32.self)) - XCTAssertEqual(UInt8(ascii: "S"), byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual("", byteBuffer.readNullTerminatedString()) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == 7) + #expect(PostgresFrontendMessage.ID.close.rawValue == byteBuffer.readInteger(as: UInt8.self)) + #expect(6 == byteBuffer.readInteger(as: Int32.self)) + #expect(UInt8(ascii: "S") == byteBuffer.readInteger(as: UInt8.self)) + #expect("" == byteBuffer.readNullTerminatedString()) + #expect(byteBuffer.readableBytes == 0) } } diff --git a/Tests/PostgresNIOTests/New/Messages/CopyTests.swift b/Tests/PostgresNIOTests/New/Messages/CopyTests.swift index de686ae5..01136d05 100644 --- a/Tests/PostgresNIOTests/New/Messages/CopyTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/CopyTests.swift @@ -1,10 +1,10 @@ -import XCTest +import Testing import NIOCore import NIOTestUtils @testable import PostgresNIO -class CopyTests: XCTestCase { - func testDecodeCopyInResponseMessage() throws { +@Suite struct CopyTests { + @Test func testDecodeCopyInResponseMessage() throws { let expected: [PostgresBackendMessage] = [ .copyInResponse(.init(format: .textual, columnFormats: [.textual, .textual])), .copyInResponse(.init(format: .binary, columnFormats: [.binary, .binary])), @@ -15,7 +15,8 @@ class CopyTests: XCTestCase { for message in expected { guard case .copyInResponse(let message) = message else { - return XCTFail("Expected only to get copyInResponse here!") + Issue.record("Expected only to get copyInResponse here!") + return } buffer.writeBackendMessage(id: .copyInResponse ) { buffer in buffer.writeInteger(Int8(message.format.rawValue)) @@ -31,72 +32,63 @@ class CopyTests: XCTestCase { ) } - func testDecodeFailureBecauseOfEmptyMessage() { + @Test func testDecodeFailureBecauseOfEmptyMessage() { var buffer = ByteBuffer() buffer.writeBackendMessage(id: .copyInResponse) { _ in} - XCTAssertThrowsError( + #expect(throws: PostgresMessageDecodingError.self) { try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } ) - ) { - XCTAssert($0 is PostgresMessageDecodingError) } } - func testDecodeFailureBecauseOfInvalidFormat() { + @Test func testDecodeFailureBecauseOfInvalidFormat() { var buffer = ByteBuffer() buffer.writeBackendMessage(id: .copyInResponse) { buffer in buffer.writeInteger(Int8(20)) // Only 0 and 1 are valid formats } - - XCTAssertThrowsError( + + #expect(throws: PostgresMessageDecodingError.self) { try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } ) - ) { - XCTAssert($0 is PostgresMessageDecodingError) } } - func testDecodeFailureBecauseOfMissingColumnNumber() { + @Test func testDecodeFailureBecauseOfMissingColumnNumber() { var buffer = ByteBuffer() buffer.writeBackendMessage(id: .copyInResponse) { buffer in buffer.writeInteger(Int8(0)) } - - XCTAssertThrowsError( + + #expect(throws: PostgresMessageDecodingError.self) { try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } ) - ) { - XCTAssert($0 is PostgresMessageDecodingError) } } - - func testDecodeFailureBecauseOfMissingColumns() { + @Test func testDecodeFailureBecauseOfMissingColumns() { var buffer = ByteBuffer() buffer.writeBackendMessage(id: .copyInResponse) { buffer in buffer.writeInteger(Int8(0)) buffer.writeInteger(Int16(20)) // 20 columns promised, none given } - XCTAssertThrowsError( + #expect(throws: PostgresMessageDecodingError.self) { try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } ) - ) { - XCTAssert($0 is PostgresMessageDecodingError) } } - func testDecodeFailureBecauseOfInvalidColumnFormat() { + @Test func testDecodeFailureBecauseOfInvalidColumnFormat() { var buffer = ByteBuffer() buffer.writeBackendMessage(id: .copyInResponse) { buffer in buffer.writeInteger(Int8(0)) @@ -104,44 +96,42 @@ class CopyTests: XCTestCase { buffer.writeInteger(Int8(20)) // Only 0 and 1 are valid formats } - XCTAssertThrowsError( + #expect(throws: PostgresMessageDecodingError.self) { try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: [(buffer, [])], decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } ) - ) { - XCTAssert($0 is PostgresMessageDecodingError) } } - func testEncodeCopyDataHeader() { + @Test func testEncodeCopyDataHeader() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) encoder.copyDataHeader(dataLength: 3) var byteBuffer = encoder.flushBuffer() - XCTAssertEqual(byteBuffer.readableBytes, 5) - XCTAssertEqual(PostgresFrontendMessage.ID.copyData.rawValue, byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 7) + #expect(byteBuffer.readableBytes == 5) + #expect(PostgresFrontendMessage.ID.copyData.rawValue == byteBuffer.readInteger(as: UInt8.self)) + #expect(byteBuffer.readInteger(as: Int32.self) == 7) } - func testEncodeCopyDone() { + @Test func testEncodeCopyDone() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) encoder.copyDone() var byteBuffer = encoder.flushBuffer() - XCTAssertEqual(byteBuffer.readableBytes, 5) - XCTAssertEqual(PostgresFrontendMessage.ID.copyDone.rawValue, byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 4) + #expect(byteBuffer.readableBytes == 5) + #expect(PostgresFrontendMessage.ID.copyDone.rawValue == byteBuffer.readInteger(as: UInt8.self)) + #expect(byteBuffer.readInteger(as: Int32.self) == 4) } - func testEncodeCopyFail() { + @Test func testEncodeCopyFail() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) encoder.copyFail(message: "Oh, no :(") var byteBuffer = encoder.flushBuffer() - XCTAssertEqual(byteBuffer.readableBytes, 15) - XCTAssertEqual(PostgresFrontendMessage.ID.copyFail.rawValue, byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 14) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "Oh, no :(") + #expect(byteBuffer.readableBytes == 15) + #expect(PostgresFrontendMessage.ID.copyFail.rawValue == byteBuffer.readInteger(as: UInt8.self)) + #expect(byteBuffer.readInteger(as: Int32.self) == 14) + #expect(byteBuffer.readNullTerminatedString() == "Oh, no :(") } } diff --git a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift index a90d1e93..f185877a 100644 --- a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift @@ -1,10 +1,11 @@ -import XCTest +import Foundation +import Testing import NIOCore import NIOTestUtils @testable import PostgresNIO -class DataRowTests: XCTestCase { - func testDecode() { +@Suite struct DataRowTests { + @Test func testDecode() { let buffer = ByteBuffer.backendMessage(id: .dataRow) { buffer in // the data row has 3 columns buffer.writeInteger(3, as: Int16.self) @@ -26,23 +27,26 @@ class DataRowTests: XCTestCase { (buffer, [PostgresBackendMessage.dataRow(.init(columnCount: 3, bytes: rowSlice))]), ] - XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( - inputOutputPairs: expectedInOuts, - decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) + #expect(throws: Never.self) { + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: expectedInOuts, + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: false) } + ) + } } - func testIteratingElements() { + @Test func testIteratingElements() { let dataRow = DataRow.makeTestDataRow(nil, ByteBuffer(), ByteBuffer(repeating: 5, count: 10)) var iterator = dataRow.makeIterator() - XCTAssertEqual(dataRow.count, 3) - XCTAssertEqual(iterator.next(), .some(.none)) - XCTAssertEqual(iterator.next(), ByteBuffer()) - XCTAssertEqual(iterator.next(), ByteBuffer(repeating: 5, count: 10)) - XCTAssertEqual(iterator.next(), .none) + #expect(dataRow.count == 3) + #expect(iterator.next() == .some(.none)) + #expect(iterator.next() == ByteBuffer()) + #expect(iterator.next() == ByteBuffer(repeating: 5, count: 10)) + #expect(iterator.next() == .none) } - func testIndexAfterAndSubscript() { + @Test func testIndexAfterAndSubscript() { let dataRow = DataRow.makeTestDataRow( nil, ByteBuffer(), @@ -51,18 +55,18 @@ class DataRowTests: XCTestCase { ) var index = dataRow.startIndex - XCTAssertEqual(dataRow[index], .none) + #expect(dataRow[index] == .none) index = dataRow.index(after: index) - XCTAssertEqual(dataRow[index], ByteBuffer()) + #expect(dataRow[index] == ByteBuffer()) index = dataRow.index(after: index) - XCTAssertEqual(dataRow[index], ByteBuffer(repeating: 5, count: 10)) + #expect(dataRow[index] == ByteBuffer(repeating: 5, count: 10)) index = dataRow.index(after: index) - XCTAssertEqual(dataRow[index], .none) + #expect(dataRow[index] == .none) index = dataRow.index(after: index) - XCTAssertEqual(index, dataRow.endIndex) + #expect(index == dataRow.endIndex) } - func testIndexComparison() { + @Test func testIndexComparison() { let dataRow = DataRow.makeTestDataRow( nil, ByteBuffer(), @@ -73,18 +77,18 @@ class DataRowTests: XCTestCase { let startIndex = dataRow.startIndex let secondIndex = dataRow.index(after: startIndex) - XCTAssertLessThanOrEqual(startIndex, secondIndex) - XCTAssertLessThan(startIndex, secondIndex) - - XCTAssertGreaterThanOrEqual(secondIndex, startIndex) - XCTAssertGreaterThan(secondIndex, startIndex) - - XCTAssertFalse(secondIndex == startIndex) - XCTAssertEqual(secondIndex, secondIndex) - XCTAssertEqual(startIndex, startIndex) + #expect(startIndex <= secondIndex) + #expect(startIndex < secondIndex) + + #expect(secondIndex >= startIndex) + #expect(secondIndex > startIndex) + + #expect(secondIndex != startIndex) + #expect(secondIndex == secondIndex) + #expect(startIndex == startIndex) } - func testColumnSubscript() { + @Test func testColumnSubscript() { let dataRow = DataRow.makeTestDataRow( nil, ByteBuffer(), @@ -92,14 +96,14 @@ class DataRowTests: XCTestCase { nil ) - XCTAssertEqual(dataRow.count, 4) - XCTAssertEqual(dataRow[column: 0], .none) - XCTAssertEqual(dataRow[column: 1], ByteBuffer()) - XCTAssertEqual(dataRow[column: 2], ByteBuffer(repeating: 5, count: 10)) - XCTAssertEqual(dataRow[column: 3], .none) + #expect(dataRow.count == 4) + #expect(dataRow[column: 0] == .none) + #expect(dataRow[column: 1] == ByteBuffer()) + #expect(dataRow[column: 2] == ByteBuffer(repeating: 5, count: 10)) + #expect(dataRow[column: 3] == .none) } - func testWithContiguousStorageIfAvailable() { + @Test func testWithContiguousStorageIfAvailable() { let dataRow = DataRow.makeTestDataRow( nil, ByteBuffer(), @@ -107,9 +111,9 @@ class DataRowTests: XCTestCase { nil ) - XCTAssertNil(dataRow.withContiguousStorageIfAvailable { _ in - return XCTFail("DataRow does not have a contiguous storage") - }) + #expect(dataRow.withContiguousStorageIfAvailable { _ in + Issue.record("DataRow does not have a contiguous storage") + } == nil) } } diff --git a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift index cb3c745b..42a521aa 100644 --- a/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DescribeTests.swift @@ -1,33 +1,32 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class DescribeTests: XCTestCase { - - func testEncodeDescribePortal() { +@Suite struct DescribeTests { + @Test func testEncodeDescribePortal() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) encoder.describePortal("Hello") var byteBuffer = encoder.flushBuffer() - XCTAssertEqual(byteBuffer.readableBytes, 12) - XCTAssertEqual(PostgresFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual(11, byteBuffer.readInteger(as: Int32.self)) - XCTAssertEqual(UInt8(ascii: "P"), byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual("Hello", byteBuffer.readNullTerminatedString()) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == 12) + #expect(PostgresFrontendMessage.ID.describe.rawValue == byteBuffer.readInteger(as: UInt8.self)) + #expect(11 == byteBuffer.readInteger(as: Int32.self)) + #expect(UInt8(ascii: "P") == byteBuffer.readInteger(as: UInt8.self)) + #expect("Hello" == byteBuffer.readNullTerminatedString()) + #expect(byteBuffer.readableBytes == 0) } - - func testEncodeDescribeUnnamedStatement() { + + @Test func testEncodeDescribeUnnamedStatement() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) encoder.describePreparedStatement("") var byteBuffer = encoder.flushBuffer() - XCTAssertEqual(byteBuffer.readableBytes, 7) - XCTAssertEqual(PostgresFrontendMessage.ID.describe.rawValue, byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual(6, byteBuffer.readInteger(as: Int32.self)) - XCTAssertEqual(UInt8(ascii: "S"), byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual("", byteBuffer.readNullTerminatedString()) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == 7) + #expect(PostgresFrontendMessage.ID.describe.rawValue == byteBuffer.readInteger(as: UInt8.self)) + #expect(6 == byteBuffer.readInteger(as: Int32.self)) + #expect(UInt8(ascii: "S") == byteBuffer.readInteger(as: UInt8.self)) + #expect("" == byteBuffer.readNullTerminatedString()) + #expect(byteBuffer.readableBytes == 0) } } diff --git a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift index 834ad0dd..985ab10e 100644 --- a/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift @@ -1,18 +1,17 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class ExecuteTests: XCTestCase { - - func testEncodeExecute() { +@Suite struct ExecuteTests { + @Test func testEncodeExecute() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) encoder.execute(portalName: "", maxNumberOfRows: 0) var byteBuffer = encoder.flushBuffer() - XCTAssertEqual(byteBuffer.readableBytes, 10) // 1 (id) + 4 (length) + 1 (empty null terminated string) + 4 (count) - XCTAssertEqual(PostgresFrontendMessage.ID.execute.rawValue, byteBuffer.readInteger(as: UInt8.self)) - XCTAssertEqual(9, byteBuffer.readInteger(as: Int32.self)) // length - XCTAssertEqual("", byteBuffer.readNullTerminatedString()) - XCTAssertEqual(0, byteBuffer.readInteger(as: Int32.self)) + #expect(byteBuffer.readableBytes == 10) // 1 (id) + 4 (length) + 1 (empty null terminated string) + 4 (count) + #expect(PostgresFrontendMessage.ID.execute.rawValue == byteBuffer.readInteger(as: UInt8.self)) + #expect(9 == byteBuffer.readInteger(as: Int32.self)) // length + #expect("" == byteBuffer.readNullTerminatedString()) + #expect(0 == byteBuffer.readInteger(as: Int32.self)) } } diff --git a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift index 9f81e4e4..e40dbbfe 100644 --- a/Tests/PostgresNIOTests/New/Messages/ParseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/ParseTests.swift @@ -1,9 +1,9 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class ParseTests: XCTestCase { - func testEncode() { +@Suite struct ParseTests { + @Test func testEncode() { let preparedStatementName = "test" let query = "SELECT version()" let parameters: [PostgresDataType] = [.bool, .int8, .bytea, .varchar, .text, .uuid, .json, .jsonbArray] @@ -22,14 +22,14 @@ class ParseTests: XCTestCase { // + 4 preparedStatement (3 + 1 null terminator) // + 1 query () - XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.parse.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), preparedStatementName) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), query) - XCTAssertEqual(byteBuffer.readInteger(as: UInt16.self), UInt16(parameters.count)) + #expect(byteBuffer.readableBytes == length) + #expect(byteBuffer.readInteger(as: UInt8.self) == PostgresFrontendMessage.ID.parse.rawValue) + #expect(byteBuffer.readInteger(as: Int32.self) == Int32(length - 1)) + #expect(byteBuffer.readNullTerminatedString() == preparedStatementName) + #expect(byteBuffer.readNullTerminatedString() == query) + #expect(byteBuffer.readInteger(as: UInt16.self) == UInt16(parameters.count)) for dataType in parameters { - XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), dataType.rawValue) + #expect(byteBuffer.readInteger(as: UInt32.self) == dataType.rawValue) } } } diff --git a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift index 4a4833d2..cf4ad83f 100644 --- a/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/PasswordTests.swift @@ -1,10 +1,9 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class PasswordTests: XCTestCase { - - func testEncodePassword() { +@Suite struct PasswordTests { + @Test func testEncodePassword() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) // md522d085ed8dc3377968dc1c1a40519a2a = "abc123" with salt 1, 2, 3, 4 let password = "md522d085ed8dc3377968dc1c1a40519a2a" @@ -13,9 +12,9 @@ class PasswordTests: XCTestCase { let expectedLength = 41 // 1 (id) + 4 (length) + 35 (string) + 1 (null termination) - XCTAssertEqual(byteBuffer.readableBytes, expectedLength) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.password.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(expectedLength - 1)) // length - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "md522d085ed8dc3377968dc1c1a40519a2a") + #expect(byteBuffer.readableBytes == expectedLength) + #expect(byteBuffer.readInteger(as: UInt8.self) == PostgresFrontendMessage.ID.password.rawValue) + #expect(byteBuffer.readInteger(as: Int32.self) == Int32(expectedLength - 1)) // length + #expect(byteBuffer.readNullTerminatedString() == "md522d085ed8dc3377968dc1c1a40519a2a") } } diff --git a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift index 90aa6b34..7ba31057 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift @@ -1,10 +1,10 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class SASLInitialResponseTests: XCTestCase { +@Suite struct SASLInitialResponseTests { - func testEncode() { + @Test func testEncode() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) let saslMechanism = "hello" let initialData: [UInt8] = [0, 1, 2, 3, 4, 5, 6, 7] @@ -19,16 +19,16 @@ class SASLInitialResponseTests: XCTestCase { // + 4 initialData length // + 8 initialData - XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslInitialResponse.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), saslMechanism) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(initialData.count)) - XCTAssertEqual(byteBuffer.readBytes(length: initialData.count), initialData) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == length) + #expect(byteBuffer.readInteger(as: UInt8.self) == PostgresFrontendMessage.ID.saslInitialResponse.rawValue) + #expect(byteBuffer.readInteger(as: Int32.self) == Int32(length - 1)) + #expect(byteBuffer.readNullTerminatedString() == saslMechanism) + #expect(byteBuffer.readInteger(as: Int32.self) == Int32(initialData.count)) + #expect(byteBuffer.readBytes(length: initialData.count) == initialData) + #expect(byteBuffer.readableBytes == 0) } - func testEncodeWithoutData() { + @Test func testEncodeWithoutData() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) let saslMechanism = "hello" let initialData: [UInt8] = [] @@ -43,12 +43,12 @@ class SASLInitialResponseTests: XCTestCase { // + 4 initialData length // + 0 initialData - XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslInitialResponse.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), saslMechanism) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(-1)) - XCTAssertEqual(byteBuffer.readBytes(length: initialData.count), initialData) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == length) + #expect(byteBuffer.readInteger(as: UInt8.self) == PostgresFrontendMessage.ID.saslInitialResponse.rawValue) + #expect(byteBuffer.readInteger(as: Int32.self) == Int32(length - 1)) + #expect(byteBuffer.readNullTerminatedString() == saslMechanism) + #expect(byteBuffer.readInteger(as: Int32.self) == Int32(-1)) + #expect(byteBuffer.readBytes(length: initialData.count) == initialData) + #expect(byteBuffer.readableBytes == 0) } } diff --git a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift index cdb0f10b..a2a06418 100644 --- a/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift @@ -1,10 +1,10 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class SASLResponseTests: XCTestCase { +@Suite struct SASLResponseTests { - func testEncodeWithData() { + @Test func testEncodeWithData() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) let data: [UInt8] = [0, 1, 2, 3, 4, 5, 6, 7] encoder.saslResponse(data) @@ -12,14 +12,14 @@ class SASLResponseTests: XCTestCase { let length: Int = 1 + 4 + (data.count) - XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslResponse.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readBytes(length: data.count), data) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == length) + #expect(byteBuffer.readInteger(as: UInt8.self) == PostgresFrontendMessage.ID.saslResponse.rawValue) + #expect(byteBuffer.readInteger(as: Int32.self) == Int32(length - 1)) + #expect(byteBuffer.readBytes(length: data.count) == data) + #expect(byteBuffer.readableBytes == 0) } - func testEncodeWithoutData() { + @Test func testEncodeWithoutData() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) let data: [UInt8] = [] encoder.saslResponse(data) @@ -27,9 +27,9 @@ class SASLResponseTests: XCTestCase { let length: Int = 1 + 4 - XCTAssertEqual(byteBuffer.readableBytes, length) - XCTAssertEqual(byteBuffer.readInteger(as: UInt8.self), PostgresFrontendMessage.ID.saslResponse.rawValue) - XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1)) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == length) + #expect(byteBuffer.readInteger(as: UInt8.self) == PostgresFrontendMessage.ID.saslResponse.rawValue) + #expect(byteBuffer.readInteger(as: Int32.self) == Int32(length - 1)) + #expect(byteBuffer.readableBytes == 0) } } diff --git a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift index 5af3bf34..23d022d9 100644 --- a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift @@ -1,10 +1,9 @@ -import XCTest +import Testing import NIOCore @testable import PostgresNIO -class StartupTests: XCTestCase { - - func testStartupMessageWithDatabase() { +@Suite struct StartupTests { + @Test func testStartupMessageWithDatabase() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) var byteBuffer = ByteBuffer() @@ -15,18 +14,18 @@ class StartupTests: XCTestCase { byteBuffer = encoder.flushBuffer() let byteBufferLength = Int32(byteBuffer.readableBytes) - XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) - XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123") - XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) + #expect(byteBufferLength == byteBuffer.readInteger()) + #expect(PostgresFrontendMessage.Startup.versionThree == byteBuffer.readInteger()) + #expect(byteBuffer.readNullTerminatedString() == "user") + #expect(byteBuffer.readNullTerminatedString() == "test") + #expect(byteBuffer.readNullTerminatedString() == "database") + #expect(byteBuffer.readNullTerminatedString() == "abc123") + #expect(byteBuffer.readInteger() == UInt8(0)) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == 0) } - func testStartupMessageWithoutDatabase() { + @Test func testStartupMessageWithoutDatabase() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) var byteBuffer = ByteBuffer() @@ -36,16 +35,16 @@ class StartupTests: XCTestCase { byteBuffer = encoder.flushBuffer() let byteBufferLength = Int32(byteBuffer.readableBytes) - XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) - XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") - XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) + #expect(byteBufferLength == byteBuffer.readInteger()) + #expect(PostgresFrontendMessage.Startup.versionThree == byteBuffer.readInteger()) + #expect(byteBuffer.readNullTerminatedString() == "user") + #expect(byteBuffer.readNullTerminatedString() == "test") + #expect(byteBuffer.readInteger() == UInt8(0)) - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBuffer.readableBytes == 0) } - func testStartupMessageWithAdditionalOptions() { + @Test func testStartupMessageWithAdditionalOptions() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) var byteBuffer = ByteBuffer() @@ -56,17 +55,17 @@ class StartupTests: XCTestCase { byteBuffer = encoder.flushBuffer() let byteBufferLength = Int32(byteBuffer.readableBytes) - XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) - XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "some") - XCTAssertEqual(byteBuffer.readNullTerminatedString(), "options") - XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) - - XCTAssertEqual(byteBuffer.readableBytes, 0) + #expect(byteBufferLength == byteBuffer.readInteger()) + #expect(PostgresFrontendMessage.Startup.versionThree == byteBuffer.readInteger()) + #expect(byteBuffer.readNullTerminatedString() == "user") + #expect(byteBuffer.readNullTerminatedString() == "test") + #expect(byteBuffer.readNullTerminatedString() == "database") + #expect(byteBuffer.readNullTerminatedString() == "abc123") + #expect(byteBuffer.readNullTerminatedString() == "some") + #expect(byteBuffer.readNullTerminatedString() == "options") + #expect(byteBuffer.readInteger() == UInt8(0)) + + #expect(byteBuffer.readableBytes == 0) } } diff --git a/Tests/PostgresNIOTests/New/PostgresCellTests.swift b/Tests/PostgresNIOTests/New/PostgresCellTests.swift index 6458d063..4ed64b3f 100644 --- a/Tests/PostgresNIOTests/New/PostgresCellTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresCellTests.swift @@ -1,9 +1,9 @@ @testable import PostgresNIO -import XCTest +import Testing import NIOCore -final class PostgresCellTests: XCTestCase { - func testDecodingANonOptionalString() { +@Suite struct PostgresCellTests { + @Test func testDecodingANonOptionalString() { let cell = PostgresCell( bytes: ByteBuffer(string: "Hello world"), dataType: .text, @@ -13,11 +13,13 @@ final class PostgresCellTests: XCTestCase { ) var result: String? - XCTAssertNoThrow(result = try cell.decode(String.self, context: .default)) - XCTAssertEqual(result, "Hello world") + #expect(throws: Never.self) { + result = try cell.decode(String.self, context: .default) + } + #expect(result == "Hello world") } - func testDecodingAnOptionalString() { + @Test func testDecodingAnOptionalString() { let cell = PostgresCell( bytes: nil, dataType: .text, @@ -27,11 +29,13 @@ final class PostgresCellTests: XCTestCase { ) var result: String? = "test" - XCTAssertNoThrow(result = try cell.decode(String?.self, context: .default)) - XCTAssertNil(result) + #expect(throws: Never.self) { + result = try cell.decode(String?.self, context: .default) + } + #expect(result == nil) } - func testDecodingFailure() { + @Test func testDecodingFailure() { let cell = PostgresCell( bytes: ByteBuffer(string: "Hello world"), dataType: .text, @@ -40,19 +44,44 @@ final class PostgresCellTests: XCTestCase { columnIndex: 1 ) - XCTAssertThrowsError(try cell.decode(Int?.self, context: .default)) { - guard let error = $0 as? PostgresDecodingError else { - return XCTFail("Unexpected error") + #if compiler(>=6.1) + let error = #expect(throws: PostgresDecodingError.self) { + try cell.decode(Int?.self, context: .default) + } + guard let error else { + Issue.record("Expected error at this point") + return + } + + #expect(error.file == #fileID) + #expect(error.line == #line - 9) + #expect(error.code == .typeMismatch) + #expect(error.columnName == "hello") + #expect(error.columnIndex == 1) + let correctType = error.targetType == Int?.self + #expect(correctType) + #expect(error.postgresType == .text) + #expect(error.postgresFormat == .binary) + #else + do { + _ = try cell.decode(Int?.self, context: .default) + Issue.record("Expected to throw") + } catch { + guard let error = error as? PostgresDecodingError else { + Issue.record("Expected error at this point") + return } - XCTAssertEqual(error.file, #fileID) - XCTAssertEqual(error.line, #line - 6) - XCTAssertEqual(error.code, .typeMismatch) - XCTAssertEqual(error.columnName, "hello") - XCTAssertEqual(error.columnIndex, 1) - XCTAssert(error.targetType == Int?.self) - XCTAssertEqual(error.postgresType, .text) - XCTAssertEqual(error.postgresFormat, .binary) + #expect(error.file == #fileID) + #expect(error.line == #line - 9) + #expect(error.code == .typeMismatch) + #expect(error.columnName == "hello") + #expect(error.columnIndex == 1) + let correctType = error.targetType == Int?.self + #expect(correctType) + #expect(error.postgresType == .text) + #expect(error.postgresFormat == .binary) } + #endif } } diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index 9d662252..54f13e96 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -1,15 +1,15 @@ import Atomics import NIOEmbedded import NIOPosix -import XCTest +import Testing @testable import PostgresNIO import NIOCore import Logging -final class PostgresRowSequenceTests: XCTestCase { +@Suite struct PostgresRowSequenceTests { let logger = Logger(label: "PSQLRowStreamTests") - func testBackpressureWorks() async throws { + @Test func testBackpressureWorks() async throws { let dataSource = MockRowDataSource() let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( @@ -24,22 +24,22 @@ final class PostgresRowSequenceTests: XCTestCase { ) let rowSequence = stream.asyncSequence() - XCTAssertEqual(dataSource.requestCount, 0) + #expect(dataSource.requestCount == 0) let dataRow: DataRow = [ByteBuffer(integer: Int64(1))] stream.receive([dataRow]) var iterator = rowSequence.makeAsyncIterator() let row = try await iterator.next() - XCTAssertEqual(dataSource.requestCount, 1) - XCTAssertEqual(row?.data, dataRow) + #expect(dataSource.requestCount == 1) + #expect(row?.data == dataRow) stream.receive(completion: .success("SELECT 1")) let empty = try await iterator.next() - XCTAssertNil(empty) + #expect(empty == nil) } - func testCancellationWorksWhileIterating() async throws { + @Test func testCancellationWorksWhileIterating() async throws { let dataSource = MockRowDataSource() let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( @@ -54,13 +54,13 @@ final class PostgresRowSequenceTests: XCTestCase { ) let rowSequence = stream.asyncSequence() - XCTAssertEqual(dataSource.requestCount, 0) + #expect(dataSource.requestCount == 0) let dataRows: [DataRow] = (0..<128).map { [ByteBuffer(integer: Int64($0))] } stream.receive(dataRows) var counter = 0 for try await row in rowSequence { - XCTAssertEqual(try row.decode(Int.self), counter) + #expect(try row.decode(Int.self) == counter) counter += 1 if counter == 64 { @@ -68,10 +68,10 @@ final class PostgresRowSequenceTests: XCTestCase { } } - XCTAssertEqual(dataSource.cancelCount, 1) + #expect(dataSource.cancelCount == 1) } - func testCancellationWorksBeforeIterating() async throws { + @Test func testCancellationWorksBeforeIterating() async throws { let dataSource = MockRowDataSource() let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( @@ -86,18 +86,18 @@ final class PostgresRowSequenceTests: XCTestCase { ) let rowSequence = stream.asyncSequence() - XCTAssertEqual(dataSource.requestCount, 0) + #expect(dataSource.requestCount == 0) let dataRows: [DataRow] = (0..<128).map { [ByteBuffer(integer: Int64($0))] } stream.receive(dataRows) var iterator: PostgresRowSequence.AsyncIterator? = rowSequence.makeAsyncIterator() iterator = nil - XCTAssertEqual(dataSource.cancelCount, 1) - XCTAssertNil(iterator, "Surpress warning") + #expect(dataSource.cancelCount == 1) + #expect(iterator == nil, "Surpress warning") } - func testDroppingTheSequenceCancelsTheSource() async throws { + @Test func testDroppingTheSequenceCancelsTheSource() throws { let dataSource = MockRowDataSource() let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( @@ -114,11 +114,11 @@ final class PostgresRowSequenceTests: XCTestCase { var rowSequence: PostgresRowSequence? = stream.asyncSequence() rowSequence = nil - XCTAssertEqual(dataSource.cancelCount, 1) - XCTAssertNil(rowSequence, "Surpress warning") + #expect(dataSource.cancelCount == 1) + #expect(rowSequence == nil, "Surpress warning") } - func testStreamBasedOnCompletedQuery() async throws { + @Test func testStreamBasedOnCompletedQuery() async throws { let dataSource = MockRowDataSource() let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( @@ -139,14 +139,14 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 0 for try await row in rowSequence { - XCTAssertEqual(try row.decode(Int.self), counter) + #expect(try row.decode(Int.self) == counter) counter += 1 } - XCTAssertEqual(dataSource.cancelCount, 0) + #expect(dataSource.cancelCount == 0) } - func testStreamIfInitializedWithAllData() async throws { + @Test func testStreamIfInitializedWithAllData() async throws { let dataSource = MockRowDataSource() let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( @@ -168,14 +168,14 @@ final class PostgresRowSequenceTests: XCTestCase { var counter = 0 for try await row in rowSequence { - XCTAssertEqual(try row.decode(Int.self), counter) + #expect(try row.decode(Int.self) == counter) counter += 1 } - XCTAssertEqual(dataSource.cancelCount, 0) + #expect(dataSource.cancelCount == 0) } - func testStreamIfInitializedWithError() async throws { + @Test func testStreamIfInitializedWithError() async throws { let dataSource = MockRowDataSource() let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( @@ -198,13 +198,13 @@ final class PostgresRowSequenceTests: XCTestCase { for try await _ in rowSequence { counter += 1 } - XCTFail("Expected that an error was thrown before.") + Issue.record("Expected that an error was thrown before.") } catch { - XCTAssertEqual(error as? PSQLError, .serverClosedConnection(underlying: nil)) + #expect(error as? PSQLError == .serverClosedConnection(underlying: nil)) } } - func testSucceedingRowContinuationsWorks() async throws { + @Test func testSucceedingRowContinuationsWorks() async throws { let dataSource = MockRowDataSource() let eventLoop = NIOSingletons.posixEventLoopGroup.next() let stream = PSQLRowStream( @@ -227,17 +227,17 @@ final class PostgresRowSequenceTests: XCTestCase { } let row1 = try await rowIterator.next() - XCTAssertEqual(try row1?.decode(Int.self), 0) + #expect(try row1?.decode(Int.self) == 0) eventLoop.scheduleTask(in: .seconds(1)) { stream.receive(completion: .success("SELECT 1")) } let row2 = try await rowIterator.next() - XCTAssertNil(row2) + #expect(row2 == nil) } - func testFailingRowContinuationsWorks() async throws { + @Test func testFailingRowContinuationsWorks() async throws { let dataSource = MockRowDataSource() let eventLoop = NIOSingletons.posixEventLoopGroup.next() let stream = PSQLRowStream( @@ -260,7 +260,7 @@ final class PostgresRowSequenceTests: XCTestCase { } let row1 = try await rowIterator.next() - XCTAssertEqual(try row1?.decode(Int.self), 0) + #expect(try row1?.decode(Int.self) == 0) eventLoop.scheduleTask(in: .seconds(1)) { stream.receive(completion: .failure(PSQLError.serverClosedConnection(underlying: nil))) @@ -268,13 +268,13 @@ final class PostgresRowSequenceTests: XCTestCase { do { _ = try await rowIterator.next() - XCTFail("Expected that an error was thrown before.") + Issue.record("Expected that an error was thrown before.") } catch { - XCTAssertEqual(error as? PSQLError, .serverClosedConnection(underlying: nil)) + #expect(error as? PSQLError == .serverClosedConnection(underlying: nil)) } } - func testAdaptiveRowBufferShrinksAndGrows() async throws { + @Test func testAdaptiveRowBufferShrinksAndGrows() async throws { let dataSource = MockRowDataSource() let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( @@ -294,20 +294,20 @@ final class PostgresRowSequenceTests: XCTestCase { let rowSequence = stream.asyncSequence() var rowIterator = rowSequence.makeAsyncIterator() - XCTAssertEqual(dataSource.requestCount, 0) + #expect(dataSource.requestCount == 0) _ = try await rowIterator.next() // new buffer size will be target -> don't ask for more - XCTAssertEqual(dataSource.requestCount, 0) + #expect(dataSource.requestCount == 0) _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more - XCTAssertEqual(dataSource.requestCount, 1) + #expect(dataSource.requestCount == 1) // if the buffer gets new rows so that it has equal or more than target (the target size // should be halved), however shrinking is only allowed AFTER the first extra rows were // received. let addDataRows1: [DataRow] = [[ByteBuffer(integer: Int64(0))]] stream.receive(addDataRows1) - XCTAssertEqual(dataSource.requestCount, 1) + #expect(dataSource.requestCount == 1) _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more - XCTAssertEqual(dataSource.requestCount, 2) + #expect(dataSource.requestCount == 2) // if the buffer gets new rows so that it has equal or more than target (the target size // should be halved) @@ -316,30 +316,30 @@ final class PostgresRowSequenceTests: XCTestCase { _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more for _ in 0..<(AdaptiveRowBuffer.defaultBufferTarget / 2) { _ = try await rowIterator.next() // Remove all rows until we are back at target - XCTAssertEqual(dataSource.requestCount, 2) + #expect(dataSource.requestCount == 2) } // if we remove another row we should trigger getting new rows. _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more - XCTAssertEqual(dataSource.requestCount, 3) + #expect(dataSource.requestCount == 3) // remove all remaining rows... this will trigger a target size double for _ in 0..<(AdaptiveRowBuffer.defaultBufferTarget/2 - 1) { _ = try await rowIterator.next() // Remove all rows until we are back at target - XCTAssertEqual(dataSource.requestCount, 3) + #expect(dataSource.requestCount == 3) } let fillBufferDataRows: [DataRow] = (0.. don't ask for more - XCTAssertEqual(dataSource.requestCount, 3) + #expect(dataSource.requestCount == 3) _ = try await rowIterator.next() // new buffer will be (target - 1) -> ask for more - XCTAssertEqual(dataSource.requestCount, 4) + #expect(dataSource.requestCount == 4) } - func testAdaptiveRowShrinksToMin() async throws { + @Test func testAdaptiveRowShrinksToMin() async throws { let dataSource = MockRowDataSource() let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( @@ -362,9 +362,9 @@ final class PostgresRowSequenceTests: XCTestCase { var rowIterator = rowSequence.makeAsyncIterator() // shrinking the buffer is only allowed after the first extra rows were received - XCTAssertEqual(dataSource.requestCount, 0) + #expect(dataSource.requestCount == 0) _ = try await rowIterator.next() - XCTAssertEqual(dataSource.requestCount, 1) + #expect(dataSource.requestCount == 1) stream.receive([[ByteBuffer(integer: Int64(1))]]) @@ -373,10 +373,10 @@ final class PostgresRowSequenceTests: XCTestCase { while currentTarget > AdaptiveRowBuffer.defaultBufferMinimum { // the buffer is filled up to currentTarget at that point, if we remove one row and add // one row it should shrink - XCTAssertEqual(dataSource.requestCount, expectedRequestCount) + #expect(dataSource.requestCount == expectedRequestCount) _ = try await rowIterator.next() expectedRequestCount += 1 - XCTAssertEqual(dataSource.requestCount, expectedRequestCount) + #expect(dataSource.requestCount == expectedRequestCount) stream.receive([[ByteBuffer(integer: Int64(1))], [ByteBuffer(integer: Int64(1))]]) let newTarget = currentTarget / 2 @@ -385,16 +385,16 @@ final class PostgresRowSequenceTests: XCTestCase { // consume all messages that are to much. for _ in 0.. Date: Wed, 24 Sep 2025 16:28:24 +0200 Subject: [PATCH 291/292] PostgresNIOTests: PostgresConnectionTests uses Swift testing (#591) --- .../New/PostgresConnectionTests.swift | 945 +++++++++--------- 1 file changed, 481 insertions(+), 464 deletions(-) diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index d0f8e2b0..b4658079 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -1,27 +1,29 @@ import NIOCore import NIOPosix import NIOEmbedded -import XCTest +import Testing import Logging @testable import PostgresNIO -class PostgresConnectionTests: XCTestCase { +@Suite struct PostgresConnectionTests { let logger = Logger(label: "PostgresConnectionTests") - func testConnectionFailure() { + @Test func testConnectionFailure() { // We start a local server and close it immediately to ensure that the port // number we try to connect to is not used by any other process. - let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } - + let eventLoopGroup = NIOSingletons.posixEventLoopGroup + var tempChannel: Channel? - XCTAssertNoThrow(tempChannel = try ServerBootstrap(group: eventLoopGroup) - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)).wait()) + #expect(throws: Never.self) { + tempChannel = try ServerBootstrap(group: eventLoopGroup) + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)).wait() + } let maybePort = tempChannel?.localAddress?.port - XCTAssertNoThrow(try tempChannel?.close().wait()) + #expect(throws: Never.self) { try tempChannel?.close().wait() } guard let port = maybePort else { - return XCTFail("Could not get port number from temp started server") + Issue.record("Could not get port number from temp started server") + return } let config = PostgresConnection.Configuration( @@ -33,12 +35,14 @@ class PostgresConnectionTests: XCTestCase { var logger = Logger.psqlTest logger.logLevel = .trace - XCTAssertThrowsError(try PostgresConnection.connect(on: eventLoopGroup.next(), configuration: config, id: 1, logger: logger).wait()) { - XCTAssertTrue($0 is PSQLError) + #expect(throws: PSQLError.self) { + try PostgresConnection + .connect(on: eventLoopGroup.next(), configuration: config, id: 1, logger: logger) + .wait() } } - func testOptionsAreSentOnTheWire() async throws { + @Test func testOptionsAreSentOnTheWire() async throws { let eventLoop = NIOAsyncTestingEventLoop() let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in try channel.pipeline.syncOperations.addHandlers(ReverseByteToMessageHandler(PSQLFrontendMessageDecoder())) @@ -71,7 +75,7 @@ class PostgresConnectionTests: XCTestCase { async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: .psqlTest) let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) - XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: configuration.options.additionalStartupParameters, replication: .false)))) + #expect(message == .startup(.versionThree(parameters: .init(user: "username", database: "database", options: configuration.options.additionalStartupParameters, replication: .false)))) try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -80,269 +84,275 @@ class PostgresConnectionTests: XCTestCase { try await connection.close() } - func testSimpleListen() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + @Test func testSimpleListen() async throws { + try await self.withAsyncTestingChannel { connection, channel in + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + let events = try await connection.listen("foo") + for try await event in events { + #expect(event.payload == "wooohooo") + break + } + } - try await withThrowingTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { - let events = try await connection.listen("foo") - for try await event in events { - XCTAssertEqual(event.payload, "wooohooo") + let listenMessage = try await channel.waitForUnpreparedRequest() + #expect(listenMessage.parse.query == #"LISTEN "foo";"#) + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) + + let unlistenMessage = try await channel.waitForUnpreparedRequest() + #expect(unlistenMessage.parse.query == #"UNLISTEN "foo";"#) + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + switch await taskGroup.nextResult()! { + case .success: break + case .failure(let failure): + Issue.record("Unexpected error: \(failure)") } } - - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - - try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) - - let unlistenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) - - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - - switch await taskGroup.nextResult()! { - case .success: - break - case .failure(let failure): - XCTFail("Unexpected error: \(failure)") - } } } - func testSimpleListenDoesNotUnlistenIfThereIsAnotherSubscriber() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + @Test func testSimpleListenDoesNotUnlistenIfThereIsAnotherSubscriber() async throws { + try await self.withAsyncTestingChannel { connection, channel in - try await withThrowingTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { - let events = try await connection.listen("foo") - for try await event in events { - XCTAssertEqual(event.payload, "wooohooo") - break + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + let events = try await connection.listen("foo") + for try await event in events { + #expect(event.payload == "wooohooo") + break + } } - } - taskGroup.addTask { - let events = try await connection.listen("foo") - var counter = 0 - loop: for try await event in events { - defer { counter += 1 } - switch counter { - case 0: - XCTAssertEqual(event.payload, "wooohooo") - case 1: - XCTAssertEqual(event.payload, "wooohooo2") - break loop - default: - XCTFail("Unexpected message: \(event)") + taskGroup.addTask { + let events = try await connection.listen("foo") + var counter = 0 + loop: for try await event in events { + defer { counter += 1 } + switch counter { + case 0: + #expect(event.payload == "wooohooo") + case 1: + #expect(event.payload == "wooohooo2") + break loop + default: + Issue.record("Unexpected message: \(event)") + } } } - } - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - - try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) - try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo2"))) - - let unlistenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) - - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - - switch await taskGroup.nextResult()! { - case .success: - break - case .failure(let failure): - XCTFail("Unexpected error: \(failure)") - } - } - } + let listenMessage = try await channel.waitForUnpreparedRequest() + #expect(listenMessage.parse.query == #"LISTEN "foo";"#) - func testSimpleListenConnectionDrops() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup in - taskGroup.addTask { - let events = try await connection.listen("foo") - var iterator = events.makeAsyncIterator() - let first = try await iterator.next() - XCTAssertEqual(first?.payload, "wooohooo") - do { - _ = try await iterator.next() - XCTFail("Did not expect to not throw") - } catch { - logger.error("error", metadata: ["error": "\(error)"]) - } - } + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - - try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) - struct MyWeirdError: Error {} - channel.pipeline.fireErrorCaught(MyWeirdError()) - - switch await taskGroup.nextResult()! { - case .success: - break - case .failure(let failure): - XCTFail("Unexpected error: \(failure)") + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo2"))) + + let unlistenMessage = try await channel.waitForUnpreparedRequest() + #expect(unlistenMessage.parse.query == #"UNLISTEN "foo";"#) + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + switch await taskGroup.nextResult()! { + case .success: + break + case .failure(let failure): + Issue.record("Unexpected error: \(failure)") + } } } } - func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in - for _ in 1...2 { + @Test func testSimpleListenConnectionDrops() async throws { + try await self.withAsyncTestingChannel { connection, channel in + + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup in taskGroup.addTask { - let rows = try await connection.query("SELECT 1;", logger: logger) - var iterator = rows.decode(Int.self).makeAsyncIterator() + let events = try await connection.listen("foo") + var iterator = events.makeAsyncIterator() let first = try await iterator.next() - XCTAssertEqual(first, 1) - let second = try await iterator.next() - XCTAssertNil(second) + #expect(first?.payload == "wooohooo") + do { + _ = try await iterator.next() + Issue.record("Did not expect to not throw") + } catch { + logger.error("error", metadata: ["error": "\(error)"]) + } } - } - for i in 0...1 { let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") - - if i == 0 { - taskGroup.addTask { - try await connection.closeGracefully() - } - } + #expect(listenMessage.parse.query == #"LISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - let intDescription = RowDescription.Column( - name: "", - tableOID: 0, - columnAttributeNumber: 0, - dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary - ) - try await channel.writeInbound(PostgresBackendMessage.rowDescription(.init(columns: [intDescription]))) - try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.noData) try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.testingEventLoop.executeInContext { channel.read() } - try await channel.writeInbound(PostgresBackendMessage.dataRow([Int(1)])) - try await channel.testingEventLoop.executeInContext { channel.read() } - try await channel.writeInbound(PostgresBackendMessage.commandComplete("SELECT 1 1")) - try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - } - let terminate = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) - XCTAssertEqual(terminate, .terminate) - try await channel.closeFuture.get() - XCTAssertEqual(channel.isActive, false) + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) + struct MyWeirdError: Error {} + channel.pipeline.fireErrorCaught(MyWeirdError()) - while let taskResult = await taskGroup.nextResult() { - switch taskResult { + switch await taskGroup.nextResult()! { case .success: break case .failure(let failure): - XCTFail("Unexpected error: \(failure)") + Issue.record("Unexpected error: \(failure)") } } } } - func testCloseClosesImmediatly() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + @Test func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws { + try await self.withAsyncTestingChannel { connection, channel in + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in + for _ in 1...2 { + taskGroup.addTask { + let rows = try await connection.query("SELECT 1;", logger: logger) + var iterator = rows.decode(Int.self).makeAsyncIterator() + let first = try await iterator.next() + #expect(first == 1) + let second = try await iterator.next() + #expect(second == nil) + } + } + + for i in 0...1 { + let listenMessage = try await channel.waitForUnpreparedRequest() + #expect(listenMessage.parse.query == "SELECT 1;") - try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in - for _ in 1...2 { - taskGroup.addTask { - try await connection.query("SELECT 1;", logger: logger) + if i == 0 { + taskGroup.addTask { + try await connection.closeGracefully() + } + } + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + let intDescription = RowDescription.Column( + name: "", + tableOID: 0, + columnAttributeNumber: 0, + dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary + ) + try await channel.writeInbound(PostgresBackendMessage.rowDescription(.init(columns: [intDescription]))) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.dataRow([Int(1)])) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.commandComplete("SELECT 1 1")) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + } + + let terminate = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + #expect(terminate == .terminate) + try await channel.closeFuture.get() + #expect(!channel.isActive) + + while let taskResult = await taskGroup.nextResult() { + switch taskResult { + case .success: + break + case .failure(let failure): + Issue.record("Unexpected error: \(failure)") + } } } + } + } - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") + @Test func testCloseClosesImmediatly() async throws { + try await self.withAsyncTestingChannel { connection, channel in + + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in + for _ in 1...2 { + taskGroup.addTask { + try await connection.query("SELECT 1;", logger: logger) + } + } + + let listenMessage = try await channel.waitForUnpreparedRequest() + #expect(listenMessage.parse.query == "SELECT 1;") - async let close: () = connection.close() + async let close: () = connection.close() - try await channel.closeFuture.get() - XCTAssertEqual(channel.isActive, false) + try await channel.closeFuture.get() + #expect(!channel.isActive) - try await close + try await close - while let taskResult = await taskGroup.nextResult() { - switch taskResult { - case .success: - XCTFail("Expected queries to fail") - case .failure(let failure): - guard let error = failure as? PSQLError else { - return XCTFail("Unexpected error type: \(failure)") + while let taskResult = await taskGroup.nextResult() { + switch taskResult { + case .success: + Issue.record("Expected queries to fail") + case .failure(let failure): + guard let error = failure as? PSQLError else { + Issue.record("Unexpected error type: \(failure)") + return + } + #expect(error.code == .clientClosedConnection) } - XCTAssertEqual(error.code, .clientClosedConnection) } } } } - func testIfServerJustClosesTheErrorReflectsThat() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - let logger = self.logger + @Test func testIfServerJustClosesTheErrorReflectsThat() async throws { + try await self.withAsyncTestingChannel { connection, channel in + let logger = self.logger - async let response = try await connection.query("SELECT 1;", logger: logger) + async let response = try await connection.query("SELECT 1;", logger: logger) - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") + let listenMessage = try await channel.waitForUnpreparedRequest() + #expect(listenMessage.parse.query == "SELECT 1;") - try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelInactive() } - try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelUnregistered() } + try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelInactive() } + try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelUnregistered() } - do { - _ = try await response - XCTFail("Expected to throw") - } catch { - XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection) - } + do { + _ = try await response + Issue.record("Expected to throw") + } catch { + #expect((error as? PSQLError)?.code == .serverClosedConnection) + } - // retry on same connection + // retry on same connection - do { - _ = try await connection.query("SELECT 1;", logger: self.logger) - XCTFail("Expected to throw") - } catch { - XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection) + do { + _ = try await connection.query("SELECT 1;", logger: self.logger) + Issue.record("Expected to throw") + } catch { + #expect((error as? PSQLError)?.code == .serverClosedConnection) + } } } @@ -363,282 +373,287 @@ class PostgresConnectionTests: XCTestCase { } } - func testPreparedStatement() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "active") - let result = try await connection.execute(preparedStatement, logger: .psqlTest) - var rows = 0 - for try await database in result { - rows += 1 - XCTAssertEqual("test_database", database) + @Test func testPreparedStatement() async throws { + try await self.withAsyncTestingChannel { connection, channel in + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + #expect("test_database" == database) + } + #expect(rows == 1) } - XCTAssertEqual(rows, 1) - } - let prepareRequest = try await channel.waitForPrepareRequest() - XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.first, .text) - guard case .preparedStatement(let name) = prepareRequest.describe else { - fatalError("Describe should contain a prepared statement") - } - XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) - - try await channel.sendPrepareResponse( - parameterDescription: .init(dataTypes: [ - PostgresDataType.text - ]), - rowDescription: .init(columns: [ - .init( - name: "datname", - tableOID: 12222, - columnAttributeNumber: 2, - dataType: .name, - dataTypeSize: 64, - dataTypeModifier: -1, - format: .text - ) - ]) - ) + let prepareRequest = try await channel.waitForPrepareRequest() + #expect(prepareRequest.parse.query == "SELECT datname FROM pg_stat_activity WHERE state = $1") + #expect(prepareRequest.parse.parameters.first == .text) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + #expect(name == String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) - let preparedRequest = try await channel.waitForPreparedRequest() - XCTAssertEqual(preparedRequest.bind.preparedStatementName, String(reflecting: TestPrepareStatement.self)) - XCTAssertEqual(preparedRequest.bind.parameters.count, 1) - XCTAssertEqual(preparedRequest.bind.resultColumnFormats, [.binary]) + let preparedRequest = try await channel.waitForPreparedRequest() + #expect(preparedRequest.bind.preparedStatementName == String(reflecting: TestPrepareStatement.self)) + #expect(preparedRequest.bind.parameters.count == 1) + #expect(preparedRequest.bind.resultColumnFormats == [.binary]) - try await channel.sendPreparedResponse( - dataRows: [ - ["test_database"] - ], - commandTag: TestPrepareStatement.sql - ) + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) + } } } - func testWeDontCrashOnUnexpectedChannelEvents() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + @Test func testWeDontCrashOnUnexpectedChannelEvents() async throws { + try await self.withAsyncTestingChannel { connection, channel in - enum MyEvent { - case pleaseDontCrash + enum MyEvent { + case pleaseDontCrash + } + channel.pipeline.fireUserInboundEventTriggered(MyEvent.pleaseDontCrash) + try await connection.close() } - channel.pipeline.fireUserInboundEventTriggered(MyEvent.pleaseDontCrash) - try await connection.close() } - func testSerialExecutionOfSamePreparedStatement() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in - // Send the same prepared statement twice, but with different parameters. - // Send one first and wait to send the other request until preparation is complete - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "active") - let result = try await connection.execute(preparedStatement, logger: .psqlTest) - var rows = 0 - for try await database in result { - rows += 1 - XCTAssertEqual("test_database", database) + @Test func testSerialExecutionOfSamePreparedStatement() async throws { + try await self.withAsyncTestingChannel { connection, channel in + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Send one first and wait to send the other request until preparation is complete + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + #expect("test_database" == database) + } + #expect(rows == 1) } - XCTAssertEqual(rows, 1) - } - let prepareRequest = try await channel.waitForPrepareRequest() - XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.first, .text) - guard case .preparedStatement(let name) = prepareRequest.describe else { - fatalError("Describe should contain a prepared statement") - } - XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) - - try await channel.sendPrepareResponse( - parameterDescription: .init(dataTypes: [ - PostgresDataType.text - ]), - rowDescription: .init(columns: [ - .init( - name: "datname", - tableOID: 12222, - columnAttributeNumber: 2, - dataType: .name, - dataTypeSize: 64, - dataTypeModifier: -1, - format: .text - ) - ]) - ) + let prepareRequest = try await channel.waitForPrepareRequest() + #expect(prepareRequest.parse.query == "SELECT datname FROM pg_stat_activity WHERE state = $1") + #expect(prepareRequest.parse.parameters.first == .text) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + #expect(name == String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) - let preparedRequest1 = try await channel.waitForPreparedRequest() - var buffer = preparedRequest1.bind.parameters[0]! - let parameter1 = buffer.readString(length: buffer.readableBytes)! - XCTAssertEqual(parameter1, "active") - try await channel.sendPreparedResponse( - dataRows: [ - ["test_database"] - ], - commandTag: TestPrepareStatement.sql - ) + let preparedRequest1 = try await channel.waitForPreparedRequest() + var buffer = preparedRequest1.bind.parameters[0]! + let parameter1 = buffer.readString(length: buffer.readableBytes)! + #expect(parameter1 == "active") + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) - // Now that the statement has been prepared and executed, send another request that will only get executed - // without preparation - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "idle") - let result = try await connection.execute(preparedStatement, logger: .psqlTest) - var rows = 0 - for try await database in result { - rows += 1 - XCTAssertEqual("test_database", database) + // Now that the statement has been prepared and executed, send another request that will only get executed + // without preparation + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + #expect("test_database" == database) + } + #expect(rows == 1) } - XCTAssertEqual(rows, 1) - } - let preparedRequest2 = try await channel.waitForPreparedRequest() - buffer = preparedRequest2.bind.parameters[0]! - let parameter2 = buffer.readString(length: buffer.readableBytes)! - XCTAssertEqual(parameter2, "idle") - try await channel.sendPreparedResponse( - dataRows: [ - ["test_database"] - ], - commandTag: TestPrepareStatement.sql - ) - // Ensure we received and responded to both the requests - let parameters = [parameter1, parameter2] - XCTAssert(parameters.contains("active")) - XCTAssert(parameters.contains("idle")) + let preparedRequest2 = try await channel.waitForPreparedRequest() + buffer = preparedRequest2.bind.parameters[0]! + let parameter2 = buffer.readString(length: buffer.readableBytes)! + #expect(parameter2 == "idle") + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) + // Ensure we received and responded to both the requests + let parameters = [parameter1, parameter2] + #expect(parameters.contains("active")) + #expect(parameters.contains("idle")) + } } } - func testStatementPreparationOnlyHappensOnceWithConcurrentRequests() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in - // Send the same prepared statement twice, but with different parameters. - // Let them race to tests that requests and responses aren't mixed up - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "active") - let result = try await connection.execute(preparedStatement, logger: .psqlTest) - var rows = 0 - for try await database in result { - rows += 1 - XCTAssertEqual("test_database_active", database) + @Test func testStatementPreparationOnlyHappensOnceWithConcurrentRequests() async throws { + try await self.withAsyncTestingChannel { connection, channel in + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Let them race to tests that requests and responses aren't mixed up + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + #expect("test_database_active" == database) + } + #expect(rows == 1) } - XCTAssertEqual(rows, 1) - } - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "idle") - let result = try await connection.execute(preparedStatement, logger: .psqlTest) - var rows = 0 - for try await database in result { - rows += 1 - XCTAssertEqual("test_database_idle", database) + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + #expect("test_database_idle" == database) + } + #expect(rows == 1) } - XCTAssertEqual(rows, 1) - } - // The channel deduplicates prepare requests, we're going to see only one of them - let prepareRequest = try await channel.waitForPrepareRequest() - XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.first, .text) - guard case .preparedStatement(let name) = prepareRequest.describe else { - fatalError("Describe should contain a prepared statement") - } - XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) - - try await channel.sendPrepareResponse( - parameterDescription: .init(dataTypes: [ - PostgresDataType.text - ]), - rowDescription: .init(columns: [ - .init( - name: "datname", - tableOID: 12222, - columnAttributeNumber: 2, - dataType: .name, - dataTypeSize: 64, - dataTypeModifier: -1, - format: .text - ) - ]) - ) + // The channel deduplicates prepare requests, we're going to see only one of them + let prepareRequest = try await channel.waitForPrepareRequest() + #expect(prepareRequest.parse.query == "SELECT datname FROM pg_stat_activity WHERE state = $1") + #expect(prepareRequest.parse.parameters.first == .text) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + #expect(name == String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) - // Now both the tasks have their statements prepared. - // We should see both of their execute requests coming in, the order is nondeterministic - let preparedRequest1 = try await channel.waitForPreparedRequest() - var buffer = preparedRequest1.bind.parameters[0]! - let parameter1 = buffer.readString(length: buffer.readableBytes)! - try await channel.sendPreparedResponse( - dataRows: [ - ["test_database_\(parameter1)"] - ], - commandTag: TestPrepareStatement.sql - ) - let preparedRequest2 = try await channel.waitForPreparedRequest() - buffer = preparedRequest2.bind.parameters[0]! - let parameter2 = buffer.readString(length: buffer.readableBytes)! - try await channel.sendPreparedResponse( - dataRows: [ - ["test_database_\(parameter2)"] - ], - commandTag: TestPrepareStatement.sql - ) - // Ensure we received and responded to both the requests - let parameters = [parameter1, parameter2] - XCTAssert(parameters.contains("active")) - XCTAssert(parameters.contains("idle")) + // Now both the tasks have their statements prepared. + // We should see both of their execute requests coming in, the order is nondeterministic + let preparedRequest1 = try await channel.waitForPreparedRequest() + var buffer = preparedRequest1.bind.parameters[0]! + let parameter1 = buffer.readString(length: buffer.readableBytes)! + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database_\(parameter1)"] + ], + commandTag: TestPrepareStatement.sql + ) + let preparedRequest2 = try await channel.waitForPreparedRequest() + buffer = preparedRequest2.bind.parameters[0]! + let parameter2 = buffer.readString(length: buffer.readableBytes)! + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database_\(parameter2)"] + ], + commandTag: TestPrepareStatement.sql + ) + // Ensure we received and responded to both the requests + let parameters = [parameter1, parameter2] + #expect(parameters.contains("active")) + #expect(parameters.contains("idle")) + } } } - func testStatementPreparationFailure() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in - // Send the same prepared statement twice, but with different parameters. - // Send one first and wait to send the other request until preparation is complete - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "active") - do { - _ = try await connection.execute(preparedStatement, logger: .psqlTest) - XCTFail("Was supposed to fail") - } catch { - XCTAssert(error is PSQLError) + @Test func testStatementPreparationFailure() async throws { + try await self.withAsyncTestingChannel { connection, channel in + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Send one first and wait to send the other request until preparation is complete + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + do { + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + Issue.record("Was supposed to fail") + } catch { + #expect(error is PSQLError) + } } - } - let prepareRequest = try await channel.waitForPrepareRequest() - XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.first, .text) - guard case .preparedStatement(let name) = prepareRequest.describe else { - fatalError("Describe should contain a prepared statement") - } - XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) - - // Respond with an error taking care to return a SQLSTATE that isn't - // going to get the connection closed. - try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ - .sqlState : "26000" // invalid_sql_statement_name - ]))) - try await channel.testingEventLoop.executeInContext { channel.read() } - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - try await channel.testingEventLoop.executeInContext { channel.read() } - - - // Send another requests with the same prepared statement, which should fail straight - // away without any interaction with the server - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "idle") - do { - _ = try await connection.execute(preparedStatement, logger: .psqlTest) - XCTFail("Was supposed to fail") - } catch { - XCTAssert(error is PSQLError) + let prepareRequest = try await channel.waitForPrepareRequest() + #expect(prepareRequest.parse.query == "SELECT datname FROM pg_stat_activity WHERE state = $1") + #expect(prepareRequest.parse.parameters.first == .text) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + #expect(name == String(reflecting: TestPrepareStatement.self)) + + // Respond with an error taking care to return a SQLSTATE that isn't + // going to get the connection closed. + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .sqlState : "26000" // invalid_sql_statement_name + ]))) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await channel.testingEventLoop.executeInContext { channel.read() } + + + // Send another requests with the same prepared statement, which should fail straight + // away without any interaction with the server + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + do { + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + Issue.record("Was supposed to fail") + } catch { + #expect(error is PSQLError) + } } } } } - func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { + func withAsyncTestingChannel(_ body: (PostgresConnection, NIOAsyncTestingChannel) async throws -> ()) async throws { let eventLoop = NIOAsyncTestingEventLoop() let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in try channel.pipeline.syncOperations.addHandlers(ReverseByteToMessageHandler(PSQLFrontendMessageDecoder())) @@ -656,18 +671,20 @@ class PostgresConnectionTests: XCTestCase { let logger = self.logger async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: logger) let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) - XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false)))) + #expect(message == .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false)))) try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) let connection = try await connectionPromise - self.addTeardownBlock { - try await connection.close() + do { + try await body(connection, channel) + } catch { + } - return (connection, channel) + try await connection.close() } } From 4795a0b0762c5cca843a720969c9dfee378e5120 Mon Sep 17 00:00:00 2001 From: zunda <47569369+zunda-pixel@users.noreply.github.com> Date: Tue, 14 Oct 2025 15:17:57 +0900 Subject: [PATCH 292/292] Allow swift-crypto version 4+ (#592) Update Package.swift --- Package.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Package.swift b/Package.swift index 673a4bb2..ae0e8a5d 100644 --- a/Package.swift +++ b/Package.swift @@ -30,7 +30,7 @@ let package = Package( .package(url: "https://github.com/apple/swift-nio.git", from: "2.81.0"), .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.19.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.25.0"), - .package(url: "https://github.com/apple/swift-crypto.git", "3.9.0" ..< "4.0.0"), + .package(url: "https://github.com/apple/swift-crypto.git", "3.9.0" ..< "5.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.4.1"), .package(url: "https://github.com/apple/swift-log.git", from: "1.5.3"), .package(url: "https://github.com/swift-server/swift-service-lifecycle.git", from: "2.5.0"),